diff --git a/.gitignore b/.gitignore index a4ec12ca6b53f..7ec8d45e12c6b 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,4 @@ metastore_db/ metastore/ warehouse/ TempStatsStore/ +sql/hive-thriftserver/test_warehouses diff --git a/.rat-excludes b/.rat-excludes index 372bc2587ccc3..bccb043c2bb55 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -55,3 +55,4 @@ dist/* .*ipr .*iws logs +.*scalastyle-output.xml diff --git a/LICENSE b/LICENSE index 65e1f480d9b14..e9a1153fdc5db 100644 --- a/LICENSE +++ b/LICENSE @@ -272,7 +272,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ======================================================================== -For Py4J (python/lib/py4j0.7.egg and files in assembly/lib/net/sf/py4j): +For Py4J (python/lib/py4j-0.8.2.1-src.zip) ======================================================================== Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. @@ -532,7 +532,7 @@ The following components are provided under a BSD-style license. See project lin (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.1 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.2.1 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (ISC/BSD License) jbcrypt (org.mindrot:jbcrypt:0.3m - http://www.mindrot.org/) @@ -549,3 +549,4 @@ The following components are provided under the MIT License. See project link fo (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (MIT License) jquery (https://jquery.org/license/) diff --git a/assembly/pom.xml b/assembly/pom.xml index 567a8dd2a0d94..703f15925bc44 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -165,6 +165,16 @@ + + hive-thriftserver + + + org.apache.spark + spark-hive-thriftserver_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl diff --git a/bagel/pom.xml b/bagel/pom.xml index 90c4b095bb611..bd51b112e26fa 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-bagel_2.10 - bagel + bagel jar Spark Project Bagel diff --git a/bin/beeline b/bin/beeline new file mode 100755 index 0000000000000..09fe366c609fa --- /dev/null +++ b/bin/beeline @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +# Find the java binary +if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" +else + if [ `command -v java` ]; then + RUNNER="java" + else + echo "JAVA_HOME is not set" >&2 + exit 1 + fi +fi + +# Compute classpath using external script +classpath_output=$($FWDIR/bin/compute-classpath.sh) +if [[ "$?" != "0" ]]; then + echo "$classpath_output" + exit 1 +else + CLASSPATH=$classpath_output +fi + +CLASS="org.apache.hive.beeline.BeeLine" +exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index e81e8c060cb98..16b794a1592e8 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -52,6 +52,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" fi diff --git a/bin/pyspark b/bin/pyspark index 69b056fe28f2c..39a20e2a24a3c 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -52,7 +52,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH +export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP=$PYTHONSTARTUP diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 0ef9eea95342e..2c4b08af8d4c3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -45,7 +45,7 @@ rem Figure out which Python to use. if [%PYSPARK_PYTHON%] == [] set PYSPARK_PYTHON=python set PYTHONPATH=%FWDIR%python;%PYTHONPATH% -set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%FWDIR%python\pyspark\shell.py diff --git a/bin/run-example b/bin/run-example index 942706d733122..68a35702eddd3 100755 --- a/bin/run-example +++ b/bin/run-example @@ -29,7 +29,8 @@ if [ -n "$1" ]; then else echo "Usage: ./bin/run-example [example-args]" 1>&2 echo " - set MASTER=XX to use a specific master" 1>&2 - echo " - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression)" 1>&2 + echo " - can use abbreviated example class name relative to com.apache.spark.examples" 1>&2 + echo " (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL)" 1>&2 exit 1 fi diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index eadedd7fa61ff..b29bf90c64e90 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -32,7 +32,8 @@ rem Test that an argument was given if not "x%1"=="x" goto arg_given echo Usage: run-example ^ [example-args] echo - set MASTER=XX to use a specific master - echo - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression) + echo - can use abbreviated example class name relative to com.apache.spark.examples + echo (e.g. 
SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL) goto exit :arg_given diff --git a/bin/spark-shell b/bin/spark-shell index 850e9507ec38f..756c8179d12b6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -46,11 +46,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 4b9708a8c03f3..b56d69801171c 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell %* --class org.apache.spark.repl.Main +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql new file mode 100755 index 0000000000000..bba7f897b19bc --- /dev/null +++ b/bin/spark-sql @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# +# Shell script for starting the Spark SQL CLI + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/spark-sql [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/core/pom.xml b/core/pom.xml index 1054cec4d77bb..6d8be37037729 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-core_2.10 - core + core jar Spark Project Core @@ -150,7 +150,7 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.6 + 3.2.10 colt @@ -192,8 +192,8 @@ org.tachyonproject - tachyon - 0.4.1-thrift + tachyon-client + 0.5.0 org.apache.hadoop @@ -262,6 +262,11 @@ asm test + + junit + junit + test + com.novocode junit-interface @@ -275,7 +280,7 @@ net.sf.py4j py4j - 0.8.1 + 0.8.2.1 diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 9c55bfbb47626..12f2fe031cb1d 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -36,15 +36,21 @@ import org.apache.spark.serializer.JavaSerializer * * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` + * @param name human-readable name for use in Spark's web UI * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ class Accumulable[R, T] ( @transient initialValue: R, - param: AccumulableParam[R, T]) + param: AccumulableParam[R, T], + val name: Option[String]) extends Serializable { - val id = Accumulators.newId + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = + this(initialValue, param, None) + + val id: Long = Accumulators.newId + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false @@ -219,8 +225,10 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) - extends Accumulable[T,T](initialValue, param) +class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) + extends Accumulable[T,T](initialValue, param, name) { + def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) +} /** * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add @@ -281,4 +289,7 @@ private object Accumulators { } } } + + def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) + def stringifyValue(value: Any) = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ff0ca11749d42..79c9c451d273d 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -56,18 +56,23 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // TODO: 
Make this non optional in a future release - Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) - Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled) + // Update task metrics if context is not null + // TODO: Make context non optional in a future release + Option(context).foreach { c => + c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled + c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + } combiners.iterator } } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") - def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = + def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = { + def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) + : Iterator[(K, C)] = + { if (!externalSorting) { val combiners = new AppendOnlyMap[K,C] var kc: Product2[K, C] = null @@ -85,9 +90,12 @@ case class Aggregator[K, V, C] ( val pair = iter.next() combiners.insert(pair._1, pair._2) } - // TODO: Make this non optional in a future release - Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) - Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled) + // Update task metrics if context is not null + // TODO: Make context non-optional in a future release + Option(context).foreach { c => + c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled + c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + } combiners.iterator } } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index f010c03223ef4..ab2594cfc02eb 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -19,7 +19,6 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -28,29 +27,35 @@ import org.apache.spark.shuffle.ShuffleHandle * Base class for dependencies. */ @DeveloperApi -abstract class Dependency[T](val rdd: RDD[T]) extends Serializable +abstract class Dependency[T] extends Serializable { + def rdd: RDD[T] +} /** * :: DeveloperApi :: - * Base class for dependencies where each partition of the parent RDD is used by at most one - * partition of the child RDD. Narrow dependencies allow for pipelined execution. + * Base class for dependencies where each partition of the child RDD depends on a small number + * of partitions of the parent RDD. Narrow dependencies allow for pipelined execution. */ @DeveloperApi -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { +abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { /** * Get the parent partitions for a child partition. * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ def getParents(partitionId: Int): Seq[Int] + + override def rdd: RDD[T] = _rdd } /** * :: DeveloperApi :: - * Represents a dependency on the output of a shuffle stage. - * @param rdd the parent RDD + * Represents a dependency on the output of a shuffle stage. 
Note that in the case of shuffle, + * the RDD is transient since we don't need it on the executor side. + * + * @param _rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will @@ -58,21 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { */ @DeveloperApi class ShuffleDependency[K, V, C]( - @transient rdd: RDD[_ <: Product2[K, V]], + @transient _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, - val mapSideCombine: Boolean = false, - val sortOrder: Option[SortOrder] = None) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { + val mapSideCombine: Boolean = false) + extends Dependency[Product2[K, V]] { + + override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] - val shuffleId: Int = rdd.context.newShuffleId() + val shuffleId: Int = _rdd.context.newShuffleId() - val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( - shuffleId, rdd.partitions.size, this) + val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( + shuffleId, _rdd.partitions.size, this) - rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala new file mode 100644 index 0000000000000..24ccce21b62ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import akka.actor.Actor +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.scheduler.TaskScheduler + +/** + * A heartbeat from executors to the driver. This is a shared message used by several internal + * components to convey liveness or execution information for in-progress tasks. + */ +private[spark] case class Heartbeat( + executorId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId) + +private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) + +/** + * Lives in the driver to receive heartbeats from executors.. 
+ */ +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { + override def receive = { + case Heartbeat(executorId, taskMetrics, blockManagerId) => + val response = HeartbeatResponse( + !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) + sender ! response + } +} diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 0e3750fdde415..edc3889c9ae51 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -23,7 +23,10 @@ import com.google.common.io.Files import org.apache.spark.util.Utils -private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging { +private[spark] class HttpFileServer( + securityManager: SecurityManager, + requestedPort: Int = 0) + extends Logging { var baseDir : File = null var fileDir : File = null @@ -38,7 +41,7 @@ private[spark] class HttpFileServer(securityManager: SecurityManager) extends Lo fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir, securityManager) + httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server") httpServer.start() serverUri = httpServer.uri logDebug("HTTP file server started at: " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 7e9b517f901a2..912558d0cab7d 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -21,7 +21,7 @@ import java.io.File import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator -import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler} +import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector @@ -41,48 +41,68 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server. 
*/ -private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) - extends Logging { +private[spark] class HttpServer( + resourceBase: File, + securityManager: SecurityManager, + requestedPort: Int = 0, + serverName: String = "HTTP server") + extends Logging { + private var server: Server = null - private var port: Int = -1 + private var port: Int = requestedPort def start() { if (server != null) { throw new ServerStateException("Server is already started") } else { logInfo("Starting HTTP Server") - server = new Server() - val connector = new SocketConnector - connector.setMaxIdleTime(60*1000) - connector.setSoLingerTime(-1) - connector.setPort(0) - server.addConnector(connector) - - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - - if (securityManager.isAuthenticationEnabled()) { - logDebug("HttpServer is using security") - val sh = setupSecurityHandler(securityManager) - // make sure we go through security handler to get resources - sh.setHandler(handlerList) - server.setHandler(sh) - } else { - logDebug("HttpServer is not using security") - server.setHandler(handlerList) - } - - server.start() - port = server.getConnectors()(0).getLocalPort() + val (actualServer, actualPort) = + Utils.startServiceOnPort[Server](requestedPort, doStart, serverName) + server = actualServer + port = actualPort } } + /** + * Actually start the HTTP server on the given port. + * + * Note that this is only best effort in the sense that we may end up binding to a nearby port + * in the event of port collision. Return the bound server and the actual port used. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60 * 1000) + connector.setSoLingerTime(-1) + connector.setPort(startPort) + server.addConnector(connector) + + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val resHandler = new ResourceHandler + resHandler.setResourceBase(resourceBase.getAbsolutePath) + + val handlerList = new HandlerList + handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + + if (securityManager.isAuthenticationEnabled()) { + logDebug("HttpServer is using security") + val sh = setupSecurityHandler(securityManager) + // make sure we go through security handler to get resources + sh.setHandler(handlerList) + server.setHandler(sh) + } else { + logDebug("HttpServer is not using security") + server.setHandler(handlerList) + } + + server.start() + val actualPort = server.getConnectors()(0).getLocalPort + + (server, actualPort) + } + /** * Setup Jetty to the HashLoginService using a single user with our * shared secret. 
Configure it to use DIGEST-MD5 authentication so that the password @@ -134,7 +154,7 @@ private[spark] class HttpServer(resourceBase: File, securityManager: SecurityMan if (server == null) { throw new ServerStateException("Server is not started") } else { - return "http://" + Utils.localIpAddress + ":" + port + "http://" + Utils.localIpAddress + ":" + port } } } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 50d8e93e1f0d7..d4f2624061e35 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -39,16 +39,17 @@ trait Logging { // be serialized and used on another machine @transient private var log_ : Logger = null + // Method to get the logger name for this object + protected def logName = { + // Ignore trailing $'s in the class names for Scala objects + this.getClass.getName.stripSuffix("$") + } + // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { initializeIfNecessary() - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - log_ = LoggerFactory.getLogger(className) + log_ = LoggerFactory.getLogger(logName) } log_ } @@ -110,23 +111,27 @@ trait Logging { } private def initializeLogging() { - // If Log4j is being used, but is not initialized, load a default properties file - val binder = StaticLoggerBinder.getSingleton - val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory") - val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4jInitialized && usingLog4j) { + // Don't use a logger in here, as this is itself occurring during initialization of a logger + // If Log4j 1.2 is being used, but is not initialized, load a default properties file + val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr + // This distinguishes the log4j 1.2 binding, currently + // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently + // org.apache.logging.slf4j.Log4jLoggerFactory + val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized && usingLog4j12) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { case Some(url) => PropertyConfigurator.configure(url) - log.info(s"Using Spark's default log4j profile: $defaultLogProps") + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") case None => System.err.println(s"Spark was unable to load $defaultLogProps") } } Logging.initialized = true - // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // Force a call into slf4j to initialize it. 
Avoids this happening from multiple threads // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html log } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 52c018baa5f7b..37053bb6f37ad 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -19,11 +19,15 @@ package org.apache.spark import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import scala.reflect.ClassTag +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.{ClassTag, classTag} +import scala.util.hashing.byteswap32 -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} +import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. @@ -103,26 +107,49 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ascending: Boolean = true) extends Partitioner { + // We allow partitions = 0, which happens when sorting an empty RDD under the default settings. + require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.") + private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions private var rangeBounds: Array[K] = { - if (partitions == 1) { - Array() + if (partitions <= 1) { + Array.empty } else { - val rddSize = rdd.count() - val maxSampleSize = partitions * 20.0 - val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) - val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted - if (rddSample.length == 0) { - Array() + // This is the sample size we need to have roughly balanced output partitions, capped at 1M. + val sampleSize = math.min(20.0 * partitions, 1e6) + // Assume the input partitions are roughly balanced and over-sample a little bit. + val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt + val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) + if (numItems == 0L) { + Array.empty } else { - val bounds = new Array[K](partitions - 1) - for (i <- 0 until partitions - 1) { - val index = (rddSample.length - 1) * (i + 1) / partitions - bounds(i) = rddSample(index) + // If a partition contains much more than the average number of items, we re-sample from it + // to ensure that enough items are collected from that partition. + val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0) + val candidates = ArrayBuffer.empty[(K, Float)] + val imbalancedPartitions = mutable.Set.empty[Int] + sketched.foreach { case (idx, n, sample) => + if (fraction * n > sampleSizePerPartition) { + imbalancedPartitions += idx + } else { + // The weight is 1 over the sampling probability. + val weight = (n.toDouble / sample.size).toFloat + for (key <- sample) { + candidates += ((key, weight)) + } + } + } + if (imbalancedPartitions.nonEmpty) { + // Re-sample imbalanced partitions with the desired sampling probability. 
+ val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains) + val seed = byteswap32(-rdd.id - 1) + val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect() + val weight = (1.0 / fraction).toFloat + candidates ++= reSampled.map(x => (x, weight)) } - bounds + RangePartitioner.determineBounds(candidates, partitions) } } } @@ -212,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } } + +private[spark] object RangePartitioner { + + /** + * Sketches the input RDD via reservoir sampling on each partition. + * + * @param rdd the input RDD to sketch + * @param sampleSizePerPartition max sample size per partition + * @return (total number of items, an array of (partitionId, number of items, sample)) + */ + def sketch[K:ClassTag]( + rdd: RDD[K], + sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + val shift = rdd.id + // val classTagK = classTag[K] // to avoid serializing the entire partitioner object + val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => + val seed = byteswap32(idx ^ (shift << 16)) + val (sample, n) = SamplingUtils.reservoirSampleAndCount( + iter, sampleSizePerPartition, seed) + Iterator((idx, n, sample)) + }.collect() + val numItems = sketched.map(_._2.toLong).sum + (numItems, sketched) + } + + /** + * Determines the bounds for range partitioning from candidates with weights indicating how many + * items each represents. Usually this is 1 over the probability used to sample this candidate. + * + * @param candidates unordered candidates with weights + * @param partitions number of partitions + * @return selected bounds + */ + def determineBounds[K:Ordering:ClassTag]( + candidates: ArrayBuffer[(K, Float)], + partitions: Int): Array[K] = { + val ordering = implicitly[Ordering[K]] + val ordered = candidates.sortBy(_._1) + val numCandidates = ordered.size + val sumWeights = ordered.map(_._2.toDouble).sum + val step = sumWeights / partitions + var cumWeight = 0.0 + var target = step + val bounds = ArrayBuffer.empty[K] + var i = 0 + var j = 0 + var previousBound = Option.empty[K] + while ((i < numCandidates) && (j < partitions - 1)) { + val (key, weight) = ordered(i) + cumWeight += weight + if (cumWeight > target) { + // Skip duplicate values. + if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { + bounds += key + target += step + j += 1 + previousBound = Some(key) + } + } + i += 1 + } + bounds.toArray + } +} diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 74aa441619bd2..25c2c9fc6af7c 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -41,10 +41,19 @@ import org.apache.spark.deploy.SparkHadoopUtil * secure the UI if it has data that other users should not be allowed to see. The javax * servlet filter specified by the user can authenticate the user and then once the user * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * authorized to view the UI. The configs 'spark.acls.enable' and 'spark.ui.view.acls' * control the behavior of the acls. Note that the person who started the application * always has view access to the UI. * + * Spark has a set of modify acls (`spark.modify.acls`) that controls which users have permission + * to modify a single application. 
This would include things like killing the application. By + * default the person who started the application has modify access. For modify access through + * the UI, you must have a filter that does authentication in place for the modify acls to work + * properly. + * + * Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators + * who always have permission to view or modify the Spark application. + * * Spark does not currently support encryption after authentication. * * At this point spark has multiple communication protocols that need to be secured and @@ -137,18 +146,32 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { private val sparkSecretLookupKey = "sparkCookie" private val authOn = sparkConf.getBoolean("spark.authenticate", false) - private var uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) + // keep spark.ui.acls.enable for backwards compatibility with 1.0 + private var aclsOn = sparkConf.getOption("spark.acls.enable").getOrElse( + sparkConf.get("spark.ui.acls.enable", "false")).toBoolean + + // admin acls should be set before view or modify acls + private var adminAcls: Set[String] = + stringToSet(sparkConf.get("spark.admin.acls", "")) private var viewAcls: Set[String] = _ + + // list of users who have permission to modify the application. This should + // apply to both UI and CLI for things like killing the application. + private var modifyAcls: Set[String] = _ + // always add the current user and SPARK_USER to the viewAcls - private val defaultAclUsers = Seq[String](System.getProperty("user.name", ""), + private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), Option(System.getenv("SPARK_USER")).getOrElse("")) + setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) + setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) private val secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + - "; ui acls " + (if (uiAclsOn) "enabled" else "disabled") + - "; users with view permissions: " + viewAcls.toString()) + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + + "; users with view permissions: " + viewAcls.toString() + + "; users with modify permissions: " + modifyAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its @@ -169,18 +192,51 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { ) } - private[spark] def setViewAcls(defaultUsers: Seq[String], allowedUsers: String) { - viewAcls = (defaultUsers ++ allowedUsers.split(',')).map(_.trim()).filter(!_.isEmpty).toSet + /** + * Split a comma separated String, filter out any empty items, and return a Set of strings + */ + private def stringToSet(list: String): Set[String] = { + list.split(',').map(_.trim).filter(!_.isEmpty).toSet + } + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. 
+ */ + def setViewAcls(defaultUsers: Set[String], allowedUsers: String) { + viewAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) logInfo("Changing view acls to: " + viewAcls.mkString(",")) } - private[spark] def setViewAcls(defaultUser: String, allowedUsers: String) { - setViewAcls(Seq[String](defaultUser), allowedUsers) + def setViewAcls(defaultUser: String, allowedUsers: String) { + setViewAcls(Set[String](defaultUser), allowedUsers) + } + + def getViewAcls: String = viewAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setModifyAcls(defaultUsers: Set[String], allowedUsers: String) { + modifyAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) + logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) + } + + def getModifyAcls: String = modifyAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setAdminAcls(adminUsers: String) { + adminAcls = stringToSet(adminUsers) + logInfo("Changing admin acls to: " + adminAcls.mkString(",")) } - private[spark] def setUIAcls(aclSetting: Boolean) { - uiAclsOn = aclSetting - logInfo("Changing acls enabled to: " + uiAclsOn) + def setAcls(aclSetting: Boolean) { + aclsOn = aclSetting + logInfo("Changing acls enabled to: " + aclsOn) } /** @@ -224,22 +280,39 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * Check to see if Acls for the UI are enabled * @return true if UI authentication is enabled, otherwise false */ - def uiAclsEnabled(): Boolean = uiAclsOn + def aclsEnabled(): Boolean = aclsOn /** * Checks the given user against the view acl list to see if they have - * authorization to view the UI. If the UI acls must are disabled - * via spark.ui.acls.enable, all users have view access. + * authorization to view the UI. If the UI acls are disabled + * via spark.acls.enable, all users have view access. If the user is null + * it is assumed authentication is off and all users have access. * * @param user to see if is authorized * @return true is the user has permission, otherwise false */ def checkUIViewPermissions(user: String): Boolean = { - logDebug("user=" + user + " uiAclsEnabled=" + uiAclsEnabled() + " viewAcls=" + + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + if (aclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true } + /** + * Checks the given user against the modify acl list to see if they have + * authorization to modify the application. If the UI acls are disabled + * via spark.acls.enable, all users have modify access. If the user is null + * it is assumed authentication isn't turned on and all users have access. 
+ * + * @param user to see if is authorized + * @return true is the user has permission, otherwise false + */ + def checkModifyPermissions(user: String): Boolean = { + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + + modifyAcls.mkString(",")) + if (aclsEnabled() && (user != null) && (!modifyAcls.contains(user))) false else true + } + + /** * Check to see if authentication for the Spark communication protocols is enabled * @return true if authentication is enabled, otherwise false diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8ce4b91cae8ae..13f0bff7ee507 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -40,6 +40,8 @@ import scala.collection.mutable.HashMap */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { + import SparkConf._ + /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) @@ -198,7 +200,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * * E.g. spark.akka.option.x.y.x = "value" */ - getAll.filter {case (k, v) => k.startsWith("akka.")} + getAll.filter { case (k, _) => isAkkaConf(k) } /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) @@ -236,6 +238,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } + // Validate memory fractions + val memoryKeys = Seq( + "spark.storage.memoryFraction", + "spark.shuffle.memoryFraction", + "spark.shuffle.safetyFraction", + "spark.storage.unrollFraction", + "spark.storage.safetyFraction") + for (key <- memoryKeys) { + val value = getDouble(key, 0.5) + if (value > 1 || value < 0) { + throw new IllegalArgumentException("$key should be between 0 and 1 (was '$value').") + } + } + // Check for legacy configs sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = @@ -292,3 +308,29 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } } + +private[spark] object SparkConf { + /** + * Return whether the given config is an akka config (e.g. akka.actor.provider). + * Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). + */ + def isAkkaConf(name: String): Boolean = name.startsWith("akka.") + + /** + * Return whether the given config should be passed to an executor on start-up. + * + * Certain akka and authentication configs are required of the executor when it connects to + * the scheduler, while the rest of the spark configs can be inherited from the driver later. + */ + def isExecutorStartupConf(name: String): Boolean = { + isAkkaConf(name) || + name.startsWith("spark.akka") || + name.startsWith("spark.auth") || + isSparkPortConf(name) + } + + /** + * Return whether the given config is a Spark port config. 
+ */ + def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port") +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e6addeaf04a8..e132955f0f850 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary +import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast @@ -47,7 +48,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend -import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} +import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} @@ -289,7 +290,7 @@ class SparkContext(config: SparkConf) extends Logging { value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => executorEnvs("SPARK_PREPEND_CLASSES") = v } // The Mesos scheduler backend relies on this environment variable to set executor memory. @@ -307,6 +308,8 @@ class SparkContext(config: SparkConf) extends Logging { // Create and start the scheduler private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master) + private val heartbeatReceiver = env.actorSystem.actorOf( + Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -455,7 +458,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -757,6 +760,15 @@ class SparkContext(config: SparkConf) extends Logging { def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display + * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the + * driver can access the accumulator's `value`. 
+ */ + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { + new Accumulator(initialValue, param, Some(name)) + } + /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. @@ -766,6 +778,16 @@ class SparkContext(config: SparkConf) extends Logging { def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the + * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can + * access the accumuable's `value`. + * @tparam T accumulator type + * @tparam R type that can be added to the accumulator + */ + def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) = + new Accumulable(initialValue, param, Some(name)) + /** * Create an accumulator from a "mutable collection" type. * @@ -840,7 +862,9 @@ class SparkContext(config: SparkConf) extends Logging { */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) + val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.filter(_.isCached) } /** @@ -990,15 +1014,15 @@ class SparkContext(config: SparkConf) extends Logging { val dagSchedulerCopy = dagScheduler dagScheduler = null if (dagSchedulerCopy != null) { + env.metricsSystem.report() metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? env.stop() SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() listenerBus.stop() eventLogger.foreach(_.stop()) logInfo("Successfully stopped SparkContext") @@ -1205,10 +1229,10 @@ class SparkContext(config: SparkConf) extends Logging { /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) - * If checkSerializable is set, clean will also proactively - * check to see if f is serializable and throw a SparkException + * If checkSerializable is set, clean will also proactively + * check to see if f is serializable and throw a SparkException * if not. - * + * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability * @throws SparkException if checkSerializable is set but f is not @@ -1453,9 +1477,9 @@ object SparkContext extends Logging { /** Creates a task scheduler based on a given master URL. Extracted for testing. 
*/ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -1485,8 +1509,12 @@ object SparkContext extends Logging { scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + def localCpuCount = Runtime.getRuntime.availableProcessors() + // local[*, M] means the number of cores on the computer with M failures + // local[N, M] means exactly N threads with M failures + val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) + val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6ee731b22c03c..9d4edeb6d96cf 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -18,10 +18,10 @@ package org.apache.spark import java.io.File +import java.net.Socket import scala.collection.JavaConversions._ import scala.collection.mutable -import scala.concurrent.Await import scala.util.Properties import akka.actor._ @@ -34,7 +34,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -65,12 +65,9 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, + val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { - // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations - // All accesses should be manually synchronized - val shuffleMemoryMap = mutable.HashMap[Long, Long]() - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation @@ -102,10 +99,10 @@ class SparkEnv ( } private[spark] - def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) { + def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { synchronized { val key = (pythonExec, envVars) - pythonWorkers(key).stop() + pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } } @@ -153,10 +150,10 @@ object SparkEnv extends Logging { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, securityManager = securityManager) - // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), - // figure out which 
port number Akka actually bound to and set spark.driver.port to it. - if (isDriver && port == 0) { - conf.set("spark.driver.port", boundPort.toString) + // Figure out which port Akka actually bound to in case the original port is 0 or occupied. + // This is so that we tell the executors the correct port to connect to. + if (isDriver) { + conf.set("spark.driver.port", boundPort.toString) } // Create an instance of the class named by the given Java system property, or by @@ -193,13 +190,7 @@ object SparkEnv extends Logging { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" - val timeout = AkkaUtils.lookupTimeout(conf) - logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + AkkaUtils.makeDriverRef(name, conf, actorSystem) } } @@ -230,7 +221,8 @@ object SparkEnv extends Logging { val httpFileServer = if (isDriver) { - val server = new HttpFileServer(securityManager) + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(securityManager, fileServerPort) server.initialize() conf.set("spark.fileserver.uri", server.serverUri) server @@ -257,6 +249,8 @@ object SparkEnv extends Logging { val shuffleManager = instantiateClass[ShuffleManager]( "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -278,6 +272,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, + shuffleMemoryManager, conf) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala new file mode 100644 index 0000000000000..0ae0b4ec042e2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.hadoop.mapred.InputSplit + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaSparkContext._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} +import org.apache.spark.rdd.HadoopRDD + +@DeveloperApi +class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) + (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + extends JavaPairRDD[K, V](rdd) { + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[R]( + f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], + preservesPartitioning: Boolean = false): JavaRDD[R] = { + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + preservesPartitioning)(fakeClassTag))(fakeClassTag) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala new file mode 100644 index 0000000000000..ec4f3964d75e0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.hadoop.mapreduce.InputSplit + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaSparkContext._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} +import org.apache.spark.rdd.NewHadoopRDD + +@DeveloperApi +class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) + (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + extends JavaPairRDD[K, V](rdd) { + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. 
*/ + @DeveloperApi + def mapPartitionsWithInputSplit[R]( + f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], + preservesPartitioning: Boolean = false): JavaRDD[R] = { + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + preservesPartitioning)(fakeClassTag))(fakeClassTag) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 4f3081433a542..76d4193e96aea 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList} +import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} import scala.collection.JavaConversions._ @@ -122,13 +122,80 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over + * the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. + */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + exact: Boolean, + seed: Long): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed)) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over + * the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. + * + * Use Utils.random.nextLong as the default seed for the random number generator + */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + exact: Boolean): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * Produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via + * simple random sampling. 
+ */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + seed: Long): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, false, seed) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * Produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via + * simple random sampling. + * + * Use Utils.random.nextLong as the default seed for the random number generator + */ + def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, false, Utils.random.nextLong) + /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). @@ -716,6 +783,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) sortByKey(comp, ascending) } + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] + sortByKey(comp, ascending, numPartitions) + } + /** * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling * `collect` or `save` on the resulting RDD will return or output an ordered list of records diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8a5f8088a05ca..e0a4815940db3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -34,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns @@ -294,7 +294,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions)) + val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** @@ -314,7 +315,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) + val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** Get an RDD for a Hadoop file with an arbitrary 
InputFormat. @@ -333,7 +335,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)) + val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat @@ -351,8 +354,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopFile(path, - inputFormatClass, keyClass, valueClass)) + val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** @@ -372,7 +375,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork conf: Configuration): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(kClass) implicit val ctagV: ClassTag[V] = ClassTag(vClass) - new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)) + val rdd = sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf) + new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]]) } /** @@ -391,7 +395,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork vClass: Class[V]): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(kClass) implicit val ctagV: ClassTag[V] = ClassTag(vClass) - new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) + val rdd = sc.newAPIHadoopRDD(conf, fClass, kClass, vClass) + new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]]) } /** Build the union of two or more RDDs. */ @@ -424,6 +429,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def intAccumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue, name)(IntAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Integer]] + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -431,12 +446,31 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. 
+ */ + def doubleAccumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + sc.accumulator(initialValue, name)(DoubleAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Double]] + /** * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + intAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -444,6 +478,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator(initialValue: Double): Accumulator[java.lang.Double] = doubleAccumulator(initialValue) + + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `add` method. Only the master can access the accumulator's `value`. @@ -451,6 +495,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" + * values to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T]) + : Accumulator[T] = + sc.accumulator(initialValue, name)(accumulatorParam) + /** * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks * can "add" values with `add`. Only the master can access the accumuable's `value`. @@ -458,6 +512,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = sc.accumulable(initialValue)(param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks + * can "add" values with `add`. Only the master can access the accumuable's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R]) + : Accumulable[T, R] = + sc.accumulable(initialValue, name)(param) + /** * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. 
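[Editor's note] The JavaPairRDD and JavaSparkContext changes above add stratified sampling (sampleByKey), a sortByKey overload that also controls the partition count, JavaHadoopRDD/JavaNewHadoopRDD wrappers that expose the underlying Hadoop InputSplit, and accumulator overloads that carry a name for display in the web UI. The following is a minimal Scala sketch of how these new entry points compose; the data, object name, and master URL are placeholders invented for illustration and are not part of the patch.

    import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

    import org.apache.spark.SparkConf
    import org.apache.spark.api.java.JavaSparkContext

    object JavaApiAdditionsSketch {
      def main(args: Array[String]) {
        val jsc = new JavaSparkContext(
          new SparkConf().setAppName("java-api-additions-sketch").setMaster("local[2]"))

        // Named accumulators (new overloads above) appear in the web UI under this name.
        val skipped = jsc.intAccumulator(0, "Skipped records")

        // A small in-memory pair RDD. Note that hadoopFile/newAPIHadoopFile now return
        // JavaHadoopRDD / JavaNewHadoopRDD instances, which callers can downcast to reach
        // mapPartitionsWithInputSplit.
        val pairs = new JArrayList[(String, String)]()
        pairs.add(("a", "apple")); pairs.add(("b", "banana")); pairs.add(("a", "avocado"))
        val rdd = jsc.parallelizePairs(pairs)

        // Stratified sampling: one sampling rate per key, approximate single-pass sample
        // by default (exact = false).
        val fractions = new JHashMap[String, Double]()
        fractions.put("a", 0.5); fractions.put("b", 1.0)
        val sampled = rdd.sampleByKey(false, fractions)

        // New overload: sort by key and choose the number of output partitions.
        val sorted = sampled.sortByKey(true, 2)

        println(sorted.collect())
        println("Skipped records: " + skipped.value)
        jsc.stop()
      }
    }

As documented in the scaladoc above, the two-argument sampleByKey only approximates math.ceil(numItems * samplingRate) per key; passing exact = true trades additional passes over the RDD for an exact per-key sample size.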
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index adaa1ef6cf9ff..f3b05e1243045 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.python +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.Logging +import org.apache.spark.{Logging, SerializableWritable, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -31,13 +32,14 @@ import org.apache.spark.annotation.Experimental * transformation code by overriding the convert method. */ @Experimental -trait Converter[T, U] extends Serializable { +trait Converter[T, + U] extends Serializable { def convert(obj: T): U } private[python] object Converter extends Logging { - def getInstance(converterClass: Option[String]): Converter[Any, Any] = { + def getInstance(converterClass: Option[String], + defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] @@ -49,7 +51,7 @@ private[python] object Converter extends Logging { logError(s"Failed to load converter: $cc") throw err } - }.getOrElse { new DefaultConverter } + }.getOrElse { defaultConverter } } } @@ -57,7 +59,9 @@ private[python] object Converter extends Logging { * A converter that handles conversion of common [[org.apache.hadoop.io.Writable]] objects. * Other objects are passed through without conversion. */ -private[python] class DefaultConverter extends Converter[Any, Any] { +private[python] class WritableToJavaConverter( + conf: Broadcast[SerializableWritable[Configuration]], + batchSize: Int) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -72,17 +76,30 @@ private[python] class DefaultConverter extends Converter[Any, Any] { case fw: FloatWritable => fw.get() case t: Text => t.toString case bw: BooleanWritable => bw.get() - case byw: BytesWritable => byw.getBytes + case byw: BytesWritable => + val bytes = new Array[Byte](byw.getLength) + System.arraycopy(byw.getBytes(), 0, bytes, 0, byw.getLength) + bytes case n: NullWritable => null - case aw: ArrayWritable => aw.get().map(convertWritable(_)) - case mw: MapWritable => mapAsJavaMap(mw.map { case (k, v) => - (convertWritable(k), convertWritable(v)) - }.toMap) + case aw: ArrayWritable => + // Due to erasure, all arrays appear as Object[] and they get pickled to Python tuples. + // Since we can't determine element types for empty arrays, we will not attempt to + // convert to primitive arrays (which get pickled to Python arrays). Users may want + // write custom converters for arrays if they know the element types a priori. 
+ aw.get().map(convertWritable(_)) + case mw: MapWritable => + val map = new java.util.HashMap[Any, Any]() + mw.foreach { case (k, v) => + map.put(convertWritable(k), convertWritable(v)) + } + map + case w: Writable => + if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w case other => other } } - def convert(obj: Any): Any = { + override def convert(obj: Any): Any = { obj match { case writable: Writable => convertWritable(writable) @@ -92,6 +109,47 @@ private[python] class DefaultConverter extends Converter[Any, Any] { } } +/** + * A converter that converts common types to [[org.apache.hadoop.io.Writable]]. Note that array + * types are not supported since the user needs to subclass [[org.apache.hadoop.io.ArrayWritable]] + * to set the type properly. See [[org.apache.spark.api.python.DoubleArrayWritable]] and + * [[org.apache.spark.api.python.DoubleArrayToWritableConverter]] for an example. They are used in + * PySpark RDD `saveAsNewAPIHadoopFile` doctest. + */ +private[python] class JavaToWritableConverter extends Converter[Any, Writable] { + + /** + * Converts common data types to [[org.apache.hadoop.io.Writable]]. Note that array types are not + * supported out-of-the-box. + */ + private def convertToWritable(obj: Any): Writable = { + import collection.JavaConversions._ + obj match { + case i: java.lang.Integer => new IntWritable(i) + case d: java.lang.Double => new DoubleWritable(d) + case l: java.lang.Long => new LongWritable(l) + case f: java.lang.Float => new FloatWritable(f) + case s: java.lang.String => new Text(s) + case b: java.lang.Boolean => new BooleanWritable(b) + case aob: Array[Byte] => new BytesWritable(aob) + case null => NullWritable.get() + case map: java.util.Map[_, _] => + val mapWritable = new MapWritable() + map.foreach { case (k, v) => + mapWritable.put(convertToWritable(k), convertToWritable(v)) + } + mapWritable + case other => throw new SparkException( + s"Data of type ${other.getClass.getName} cannot be used") + } + } + + override def convert(obj: Any): Writable = obj match { + case writable: Writable => writable + case other => convertToWritable(other) + } +} + /** Utilities for working with Python objects <-> Hadoop-related objects */ private[python] object PythonHadoopUtil { @@ -118,7 +176,7 @@ private[python] object PythonHadoopUtil { /** * Converts an RDD of key-value pairs, where key and/or value could be instances of - * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)] + * [[org.apache.hadoop.io.Writable]], into an RDD of base types, or vice versa. 
*/ def convertRDD[K, V](rdd: RDD[(K, V)], keyConverter: Converter[Any, Any], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d87783efd2d01..0b5322c6fb965 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,15 +23,18 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.language.existentials import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.{InputFormat, JobConf} -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -59,8 +62,8 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - val worker: Socket = env.createPythonWorker(pythonExec, - envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir)) + envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread + val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) @@ -238,7 +241,7 @@ private[spark] class PythonRDD( if (!context.completed) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap) + env.destroyPythonWorker(pythonExec, envVars.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -365,19 +368,17 @@ private[spark] object PythonRDD extends Logging { valueClassMaybeNull: String, keyConverterClass: String, valueConverterClass: String, - minSplits: Int) = { + minSplits: Int, + batchSize: Int) = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val converted = convertRDD(rdd, keyConverterClass, 
valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -394,17 +395,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { - val conf = PythonHadoopUtil.mapToConf(confAsMap) - val baseConf = sc.hadoopConfiguration() - val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { + val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -421,15 +421,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]]( @@ -439,18 +440,14 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - val fc = fcm.runtimeClass.asInstanceOf[Class[F]] - val rdd = if (path.isDefined) { + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { sc.sc.newAPIHadoopRDD[K, V, F](conf, fc, kc, vc) } - rdd } /** @@ -467,17 +464,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { - val conf = 
PythonHadoopUtil.mapToConf(confAsMap) - val baseConf = sc.hadoopConfiguration() - val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { + val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -494,15 +490,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]]( @@ -512,18 +509,14 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - val fc = fcm.runtimeClass.asInstanceOf[Class[F]] - val rdd = if (path.isDefined) { + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { sc.sc.hadoopRDD(new JobConf(conf), fc, kc, vc) } - rdd } def writeUTF(str: String, dataOut: DataOutputStream) { @@ -543,24 +536,199 @@ private[spark] object PythonRDD extends Logging { file.close() } + private def getMergedConf(confAsMap: java.util.HashMap[String, String], + baseConf: Configuration): Configuration = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + PythonHadoopUtil.mergeConfs(baseConf, conf) + } + + private def inferKeyValueTypes[K, V](rdd: RDD[(K, V)], keyConverterClass: String = null, + valueConverterClass: String = null): (Class[_], Class[_]) = { + // Peek at an element to figure out key/value types. Since Writables are not serializable, + // we cannot call first() on the converted RDD. 
Instead, we call first() on the original RDD + // and then convert locally. + val (key, value) = rdd.first() + val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + (kc.convert(key).getClass, vc.convert(value).getClass) + } + + private def getKeyValueTypes(keyClass: String, valueClass: String): + Option[(Class[_], Class[_])] = { + for { + k <- Option(keyClass) + v <- Option(valueClass) + } yield (Class.forName(k), Class.forName(v)) + } + + private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, + defaultConverter: Converter[Any, Any]): (Converter[Any, Any], Converter[Any, Any]) = { + val keyConverter = Converter.getInstance(Option(keyConverterClass), defaultConverter) + val valueConverter = Converter.getInstance(Option(valueConverterClass), defaultConverter) + (keyConverter, valueConverter) + } + + /** + * Convert an RDD of key-value pairs from internal types to serializable types suitable for + * output, or vice versa. + */ + private def convertRDD[K, V](rdd: RDD[(K, V)], + keyConverterClass: String, + valueConverterClass: String, + defaultConverter: Converter[Any, Any]): RDD[(Any, Any)] = { + val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, + defaultConverter) + PythonHadoopUtil.convertRDD(rdd, kc, vc) + } + + /** + * Output a Python RDD of key-value pairs as a Hadoop SequenceFile using the Writable types + * we convert from the RDD's key and value types. Note that keys and values can't be + * [[org.apache.hadoop.io.Writable]] types already, since Writables are not Java + * `Serializable` and we can't peek at them. The `path` can be on any Hadoop file system. + */ + def saveAsSequenceFile[K, V, C <: CompressionCodec]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + compressionCodecClass: String) = { + saveAsHadoopFile( + pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", + null, null, null, null, new java.util.HashMap(), compressionCodecClass) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using old Hadoop + * `OutputFormat` in mapred package. Keys and values are converted to suitable output + * types using either user specified converters or, if not specified, + * [[org.apache.spark.api.python.JavaToWritableConverter]]. Post-conversion types + * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in + * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of + * this RDD. 
+ */ + def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], C <: CompressionCodec]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + outputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String], + compressionCodecClass: String) = { + val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) + val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( + inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) + val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) + val codec = Option(compressionCodecClass).map(Class.forName(_).asInstanceOf[Class[C]]) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using new Hadoop + * `OutputFormat` in mapreduce package. Keys and values are converted to suitable output + * types using either user specified converters or, if not specified, + * [[org.apache.spark.api.python.JavaToWritableConverter]]. Post-conversion types + * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in + * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of + * this RDD. + */ + def saveAsNewAPIHadoopFile[K, V, F <: NewOutputFormat[_, _]]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + outputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String]) = { + val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) + val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( + inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) + val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf) + } + /** - * Convert an RDD of serialized Python dictionaries to Scala Maps - * TODO: Support more Python types. + * Output a Python RDD of key-value pairs to any Hadoop file system, using a Hadoop conf + * converted from the passed-in `confAsMap`. The conf should set relevant output params ( + * e.g., output path, output format, etc), in the same way as it would be configured for + * a Hadoop MapReduce job. Both old and new Hadoop OutputFormat APIs are supported + * (mapred vs. mapreduce). Keys/values are converted for output using either user specified + * converters or, by default, [[org.apache.spark.api.python.JavaToWritableConverter]]. 
*/ + def saveAsHadoopDataset[K, V]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + confAsMap: java.util.HashMap[String, String], + keyConverterClass: String, + valueConverterClass: String, + useNewAPI: Boolean) = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), + keyConverterClass, valueConverterClass, new JavaToWritableConverter) + if (useNewAPI) { + converted.saveAsNewAPIHadoopDataset(conf) + } else { + converted.saveAsHadoopDataset(new JobConf(conf)) + } + } + + + /** + * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). + */ + @deprecated("PySpark does not use it anymore", "1.1") def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler - // TODO: Figure out why flatMap is necessay for pyspark iter.flatMap { row => unpickle.loads(row) match { - case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // Incase the partition doesn't have a collection + // in case of objects are pickled in batch mode + case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) + // not in batch mode case obj: JMap[String @unchecked, _] => Seq(obj.toMap) } } } } + /** + * Convert an RDD of serialized Python tuple to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { + + def toArray(obj: Any): Array[_] = { + obj match { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + } + } + + pyRDD.rdd.mapPartitions { iter => + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].map(toArray) + } else { + Seq(toArray(obj)) + } + } + }.toJavaRDD() + } + /** * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by * PySpark. @@ -591,19 +759,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) + /** + * We try to reuse a single Socket to transfer accumulator updates, as they are all added + * by the DAGScheduler's single-threaded actor anyway. + */ + @transient var socket: Socket = _ + + def openSocket(): Socket = synchronized { + if (socket == null || socket.isClosed) { + socket = new Socket(serverHost, serverPort) + } + socket + } + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = { + : JList[Array[Byte]] = synchronized { if (serverHost == null) { // This happens on the worker node, where we just want to remember all the updates val1.addAll(val2) val1 } else { // This happens on the master, where we pass the updates to Python through a socket - val socket = new Socket(serverHost, serverPort) - // SPARK-2282: Immediately reuse closed sockets because we create one per task. 
- socket.setReuseAddress(true) + val socket = openSocket() val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) @@ -617,7 +796,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: if (byteRead == -1) { throw new SparkException("EOF reached before Python server acknowledged") } - socket.close() null } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 6d3e257c4d5df..52c70712eea3d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -29,7 +29,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.1-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.2.1-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 759cbe2c46c52..7af260d0b7f26 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, InputStream, OutputStreamWriter} +import java.lang.Runtime +import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import scala.collection.mutable import scala.collection.JavaConversions._ import org.apache.spark._ @@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 + var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + + var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, @@ -58,19 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. 
*/ private def createThroughDaemon(): Socket = { + + def createSocket(): Socket = { + val socket = new Socket(daemonHost, daemonPort) + val pid = new DataInputStream(socket.getInputStream).readInt() + if (pid < 0) { + throw new IllegalStateException("Python daemon failed to launch worker") + } + daemonWorkers.put(socket, pid) + socket + } + synchronized { // Start the daemon if it hasn't been started startDaemon() // Attempt to connect, restart and retry once if it fails try { - new Socket(daemonHost, daemonPort) + createSocket() } catch { case exc: SocketException => - logWarning("Python daemon unexpectedly quit, attempting to restart") + logWarning("Failed to open socket to Python daemon:", exc) + logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() - new Socket(daemonHost, daemonPort) + createSocket() } } } @@ -101,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Wait for it to connect to our socket serverSocket.setSoTimeout(10000) try { - return serverSocket.accept() + val socket = serverSocket.accept() + simpleWorkers.put(socket, worker) + return socket } catch { case e: Exception => throw new SparkException("Python worker did not connect back in time", e) @@ -183,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String private def stopDaemon() { synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy() + if (useDaemon) { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy() + } + + daemon = null + daemonPort = 0 + } else { + simpleWorkers.mapValues(_.destroy()) } - - daemon = null - daemonPort = 0 } } def stop() { stopDaemon() } + + def stopWorker(worker: Socket) { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } + } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) + } + worker.close() + } } private object PythonWorkerFactory { diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 9a012e7254901..efc9009c088a8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -17,13 +17,14 @@ package org.apache.spark.api.python -import scala.util.Try -import org.apache.spark.rdd.RDD -import org.apache.spark.Logging -import scala.util.Success +import scala.collection.JavaConversions._ import scala.util.Failure -import net.razorvine.pickle.Pickler +import scala.util.Try +import net.razorvine.pickle.{Unpickler, Pickler} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ private[python] object SerDeUtil extends Logging { @@ -65,20 +66,52 @@ private[python] object SerDeUtil extends Logging { * by PySpark. 
By default, if serialization fails, toString is called and the string * representation is serialized */ - def rddToPython(rdd: RDD[(Any, Any)]): RDD[Array[Byte]] = { + def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) rdd.mapPartitions { iter => val pickle = new Pickler - iter.map { case (k, v) => - if (keyFailed && valueFailed) { - pickle.dumps(Array(k.toString, v.toString)) - } else if (keyFailed) { - pickle.dumps(Array(k.toString, v)) - } else if (!keyFailed && valueFailed) { - pickle.dumps(Array(k, v.toString)) + val cleaned = iter.map { case (k, v) => + val key = if (keyFailed) k.toString else k + val value = if (valueFailed) v.toString else v + Array[Any](key, value) + } + if (batchSize > 1) { + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + } else { + cleaned.map(pickle.dumps(_)) + } + } + } + + /** + * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. + */ + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def isPair(obj: Any): Boolean = { + Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + obj.asInstanceOf[Array[_]].length == 2 + } + pyRDD.mapPartitions { iter => + val unpickle = new Unpickler + val unpickled = + if (batchSerialized) { + iter.flatMap { batch => + unpickle.loads(batch) match { + case objs: java.util.List[_] => collectionAsScalaIterable(objs) + case other => throw new SparkException( + s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") + } + } } else { - pickle.dumps(Array(k, v)) + iter.map(unpickle.loads(_)) } + unpickled.map { + case obj if isPair(obj) => + // we only accept (K, V) + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index f0e3fb9aff5a0..d11db978b842e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -17,15 +17,16 @@ package org.apache.spark.api.python -import org.apache.spark.SparkContext -import org.apache.hadoop.io._ -import scala.Array import java.io.{DataOutput, DataInput} +import java.nio.charset.Charset + +import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.{SparkContext, SparkException} /** - * A class to test MsgPack serialization on the Scala side, that will be deserialized + * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python * @param str * @param int @@ -54,7 +55,13 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } } -class TestConverter extends Converter[Any, Any] { +private[python] class TestInputKeyConverter extends Converter[Any, Any] { + override def convert(obj: Any) = { + obj.asInstanceOf[IntWritable].get().toChar + } +} + +private[python] class TestInputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ override def convert(obj: Any) = { val m = obj.asInstanceOf[MapWritable] 
@@ -62,6 +69,38 @@ class TestConverter extends Converter[Any, Any] { } } +private[python] class TestOutputKeyConverter extends Converter[Any, Any] { + override def convert(obj: Any) = { + new Text(obj.asInstanceOf[Int].toString) + } +} + +private[python] class TestOutputValueConverter extends Converter[Any, Any] { + import collection.JavaConversions._ + override def convert(obj: Any) = { + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + } +} + +private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) + +private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { + override def convert(obj: Any) = obj match { + case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => + val daw = new DoubleArrayWritable + daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) + daw + case other => throw new SparkException(s"Data of type $other is not supported") + } +} + +private[python] class WritableToDoubleArrayConverter extends Converter[Any, Array[Double]] { + override def convert(obj: Any): Array[Double] = obj match { + case daw : DoubleArrayWritable => daw.get().map(_.asInstanceOf[DoubleWritable].get()) + case other => throw new SparkException(s"Data of type $other is not supported") + } +} + /** * This object contains method to generate SequenceFile test data and write it to a * given directory (probably a temp directory) @@ -97,7 +136,8 @@ object WriteInputFormatTestDataGenerator { sc.parallelize(intKeys).saveAsSequenceFile(intPath) sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) - sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes) }).saveAsSequenceFile(bytesPath) + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) } + ).saveAsSequenceFile(bytesPath) val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) sc.parallelize(bools).saveAsSequenceFile(boolPath) sc.parallelize(intKeys).map{ case (k, v) => @@ -106,19 +146,20 @@ object WriteInputFormatTestDataGenerator { // Create test data for ArrayWritable val data = Seq( - (1, Array(1.0, 2.0, 3.0)), + (1, Array()), (2, Array(3.0, 4.0, 5.0)), (3, Array(4.0, 5.0, 6.0)) ) sc.parallelize(data, numSlices = 2) .map{ case (k, v) => - (new IntWritable(k), new ArrayWritable(classOf[DoubleWritable], v.map(new DoubleWritable(_)))) - }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, ArrayWritable]](arrPath) + val va = new DoubleArrayWritable + va.set(v.map(new DoubleWritable(_))) + (new IntWritable(k), va) + }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, DoubleArrayWritable]](arrPath) // Create test data for MapWritable, with keys DoubleWritable and values Text val mapData = Seq( - (1, Map(2.0 -> "aa")), - (2, Map(3.0 -> "bb")), + (1, Map()), (2, Map(1.0 -> "cc")), (3, Map(2.0 -> "dd")), (2, Map(1.0 -> "aa")), @@ -126,9 +167,9 @@ object WriteInputFormatTestDataGenerator { ) sc.parallelize(mapData, numSlices = 2).map{ case (i, m) => val mw = new MapWritable() - val k = m.keys.head - val v = m.values.head - mw.put(new DoubleWritable(k), new Text(v)) + m.foreach { case (k, v) => + mw.put(new DoubleWritable(k), new Text(v)) + } (new IntWritable(i), mw) }.saveAsSequenceFile(mapPath) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala 
b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 487456467b23b..942dc7d7eac87 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -152,7 +152,8 @@ private[broadcast] object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) - server = new HttpServer(broadcastDir, securityManager) + val broadcastPort = conf.getInt("spark.broadcast.port", 0) + server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 86305d2ea8a09..65a1a8fd7e929 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -22,7 +22,6 @@ private[spark] class ApplicationDescription( val maxCores: Option[Int], val memoryPerSlave: Int, val command: Command, - val sparkHome: Option[String], var appUiUrl: String, val eventLogDir: Option[String] = None) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c371dc3a51c73..c07003784e8ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy -import scala.collection.JavaConversions._ -import scala.collection.mutable.Map import scala.concurrent._ import akka.actor._ @@ -50,9 +48,6 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends // TODO: We could add an env variable here and intercept it in `sc.addJar` that would // truncate filesystem paths similar to what YARN does. For now, we just require // people call `addJar` assuming the jar is in the same directory. 
- val env = Map[String, String]() - System.getenv().foreach{case (k, v) => env(k) = v} - val mainClass = "org.apache.spark.deploy.worker.DriverWrapper" val classPathConf = "spark.driver.extraClassPath" @@ -65,10 +60,13 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends cp.split(java.io.File.pathSeparator) } - val javaOptionsConf = "spark.driver.extraJavaOptions" - val javaOpts = sys.props.get(javaOptionsConf) + val extraJavaOptsConf = "spark.driver.extraJavaOptions" + val extraJavaOpts = sys.props.get(extraJavaOptsConf) + .map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ - driverArgs.driverOptions, env, classPathEntries, libraryPathEntries, javaOpts) + driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts) val driverDescription = new DriverDescription( driverArgs.jarUrl, @@ -109,6 +107,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends // Exception, if present statusResponse.exception.map { e => println(s"Exception from cluster was: $e") + e.printStackTrace() System.exit(-1) } System.exit(0) @@ -141,8 +140,10 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends */ object Client { def main(args: Array[String]) { - println("WARNING: This client is deprecated and will be removed in a future version of Spark.") - println("Use ./bin/spark-submit with \"--master spark://host:port\"") + if (!sys.props.contains("SPARK_SUBMIT")) { + println("WARNING: This client is deprecated and will be removed in a future version of Spark") + println("Use ./bin/spark-submit with \"--master spark://host:port\"") + } val conf = new SparkConf() val driverArgs = new ClientArguments(args) @@ -154,8 +155,6 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - // TODO: See if we can initialize akka so return messages are sent back using the same TCP - // flow. Else, this (sadly) requires the DriverClient be routable from the Master. 
val (actorSystem, _) = AkkaUtils.createActorSystem( "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/Command.scala b/core/src/main/scala/org/apache/spark/deploy/Command.scala index 32f3ba385084f..a2b263544c6a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Command.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Command.scala @@ -25,5 +25,5 @@ private[spark] case class Command( environment: Map[String, String], classPathEntries: Seq[String], libraryPathEntries: Seq[String], - extraJavaOptions: Option[String] = None) { + javaOpts: Seq[String]) { } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index c4f5e294a393e..696f32a6f5730 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -56,7 +56,6 @@ private[spark] object JsonProtocol { ("cores" -> obj.maxCores) ~ ("memoryperslave" -> obj.memoryPerSlave) ~ ("user" -> obj.user) ~ - ("sparkhome" -> obj.sparkHome) ~ ("command" -> obj.command.toString) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3b5642b6caa36..318509a67a36f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -46,6 +46,10 @@ object SparkSubmit { private val CLUSTER = 2 private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + // A special jar name that indicates the class being run is inside of Spark itself, and therefore + // no user jar is needed. + private val SPARK_INTERNAL = "spark-internal" + // Special primary resource names that represent shells rather than application jars. 
private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" @@ -132,8 +136,6 @@ object SparkSubmit { (clusterManager, deployMode) match { case (MESOS, CLUSTER) => printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") - case (STANDALONE, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Standalone clusters.") case (_, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python applications.") case (_, CLUSTER) if isShell(args.primaryResource) => @@ -166,9 +168,9 @@ object SparkSubmit { val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.master"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.app.name"), - OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), + OptionAssigner(args.jars, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), // Standalone cluster only OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), @@ -182,7 +184,7 @@ object SparkSubmit { OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), // Yarn cluster only - OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name", sysProp = "spark.app.name"), + OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), @@ -199,9 +201,9 @@ object SparkSubmit { sysProp = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraLibraryPath"), - OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, CLIENT, + OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), - OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, CLIENT, + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.files") @@ -257,21 +259,26 @@ object SparkSubmit { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (clusterManager == YARN && deployMode == CLUSTER) { childMainClass = "org.apache.spark.deploy.yarn.Client" - childArgs += ("--jar", args.primaryResource) + if (args.primaryResource != SPARK_INTERNAL) { + childArgs += ("--jar", args.primaryResource) + } childArgs += ("--class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) } } } + // Properties given with --conf are superceded by other options, but take precedence over + // properties in the defaults file. 
+ for ((k, v) <- args.sparkProperties) { + sysProps.getOrElseUpdate(k, v) + } + // Read from default spark properties, if any for ((k, v) <- args.getDefaultSparkProperties) { sysProps.getOrElseUpdate(k, v) } - // Spark properties included on command line take precedence - sysProps ++= args.sparkProperties - (childArgs, childClasspath, sysProps, childMainClass) } @@ -332,7 +339,7 @@ object SparkSubmit { * Return whether the given primary resource represents a user jar. */ private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) + !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) } /** @@ -349,6 +356,10 @@ object SparkSubmit { primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL } + private[spark] def isInternal(primaryResource: String): Boolean = { + primaryResource == SPARK_INTERNAL + } + /** * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3ab67a43a3b55..9391f24e71ed7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -58,7 +58,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { val sparkProperties: HashMap[String, String] = new HashMap[String, String]() parseOpts(args.toList) - loadDefaults() + mergeSparkProperties() checkRequiredArguments() /** Return default present in the currently defined defaults file. */ @@ -79,10 +79,23 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { defaultProperties } - /** Fill in any undefined values based on the current properties file or built-in defaults. */ - private def loadDefaults(): Unit = { - + /** + * Fill in any undefined values based on the default properties file or options passed in through + * the '--conf' flag. + */ + private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user + if (propertiesFile == null) { + sys.env.get("SPARK_CONF_DIR").foreach { sparkConfDir => + val sep = File.separator + val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" + val file = new File(defaultPath) + if (file.exists()) { + propertiesFile = file.getAbsolutePath + } + } + } + if (propertiesFile == null) { sys.env.get("SPARK_HOME").foreach { sparkHome => val sep = File.separator @@ -94,18 +107,20 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { } } - val defaultProperties = getDefaultSparkProperties + val properties = getDefaultSparkProperties + properties.putAll(sparkProperties) + // Use properties file as fallback for values which have a direct analog to // arguments in this script. 
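The two `getOrElseUpdate` loops above implement the precedence the new comment describes: values already set by dedicated flags win, then `--conf` entries, then `spark-defaults.conf`, because `getOrElseUpdate` only fills keys that are still missing. A small self-contained sketch of that ordering (all names here are illustrative, not Spark's):

```scala
import scala.collection.mutable

object ConfPrecedenceSketch {
  def merge(
      fromFlags: Map[String, String],     // e.g. --master, --name
      fromConf: Map[String, String],      // e.g. --conf spark.ui.port=4041
      fromDefaults: Map[String, String]   // loaded from spark-defaults.conf
    ): Map[String, String] = {
    val sysProps = mutable.HashMap[String, String]()
    sysProps ++= fromFlags
    for ((k, v) <- fromConf) sysProps.getOrElseUpdate(k, v)      // fills only missing keys
    for ((k, v) <- fromDefaults) sysProps.getOrElseUpdate(k, v)  // lowest precedence
    sysProps.toMap
  }

  def main(args: Array[String]): Unit = {
    val merged = merge(
      fromFlags = Map("spark.master" -> "yarn"),
      fromConf = Map("spark.master" -> "local[2]", "spark.ui.port" -> "4041"),
      fromDefaults = Map("spark.ui.port" -> "4040", "spark.eventLog.enabled" -> "true"))
    println(merged)
    // spark.master -> yarn, spark.ui.port -> 4041, spark.eventLog.enabled -> true
  }
}
```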
- master = Option(master).getOrElse(defaultProperties.get("spark.master").orNull) + master = Option(master).getOrElse(properties.get("spark.master").orNull) executorMemory = Option(executorMemory) - .getOrElse(defaultProperties.get("spark.executor.memory").orNull) + .getOrElse(properties.get("spark.executor.memory").orNull) executorCores = Option(executorCores) - .getOrElse(defaultProperties.get("spark.executor.cores").orNull) + .getOrElse(properties.get("spark.executor.cores").orNull) totalExecutorCores = Option(totalExecutorCores) - .getOrElse(defaultProperties.get("spark.cores.max").orNull) - name = Option(name).getOrElse(defaultProperties.get("spark.app.name").orNull) - jars = Option(jars).getOrElse(defaultProperties.get("spark.jars").orNull) + .getOrElse(properties.get("spark.cores.max").orNull) + name = Option(name).getOrElse(properties.get("spark.app.name").orNull) + jars = Option(jars).getOrElse(properties.get("spark.jars").orNull) // This supports env vars in older versions of Spark master = Option(master).getOrElse(System.getenv("MASTER")) @@ -204,8 +219,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - // Delineates parsing of Spark options from parsing of user options. var inSparkOpts = true + + // Delineates parsing of Spark options from parsing of user options. parse(opts) def parse(opts: Seq[String]): Unit = opts match { @@ -318,7 +334,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { SparkSubmit.printErrorAndExit(errMessage) case v => primaryResource = - if (!SparkSubmit.isShell(v)) { + if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { Utils.resolveURI(v).toString } else { v diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index e15a87bd38fda..88a0862b96afe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -46,11 +46,10 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) - val desc = new ApplicationDescription( - "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), - Seq()), Some("dummy-spark-home"), "ignored") + val desc = new ApplicationDescription("TestClient", Some(1), 512, + Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) client.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 01e7065c17b69..cc06540ee0647 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -36,11 +36,11 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis conf.getInt("spark.history.updateInterval", 10)) * 1000 private val logDir = conf.get("spark.history.fs.logDirectory", null) - if 
(logDir == null) { - throw new IllegalArgumentException("Logging directory must be specified.") - } + private val resolvedLogDir = Option(logDir) + .map { d => Utils.resolveURI(d) } + .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") } - private val fs = Utils.getHadoopFileSystem(logDir) + private val fs = Utils.getHadoopFileSystem(resolvedLogDir) // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L @@ -76,14 +76,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def initialize() { // Validate the log directory. - val path = new Path(logDir) + val path = new Path(resolvedLogDir) if (!fs.exists(path)) { throw new IllegalArgumentException( - "Logging directory specified does not exist: %s".format(logDir)) + "Logging directory specified does not exist: %s".format(resolvedLogDir)) } if (!fs.getFileStatus(path).isDir) { throw new IllegalArgumentException( - "Logging directory specified is not a directory: %s".format(logDir)) + "Logging directory specified is not a directory: %s".format(resolvedLogDir)) } checkForLogs() @@ -95,15 +95,16 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis override def getAppUI(appId: String): SparkUI = { try { - val appLogDir = fs.getFileStatus(new Path(logDir, appId)) - loadAppInfo(appLogDir, true)._2 + val appLogDir = fs.getFileStatus(new Path(resolvedLogDir.toString, appId)) + val (_, ui) = loadAppInfo(appLogDir, renderUI = true) + ui } catch { case e: FileNotFoundException => null } } override def getConfig(): Map[String, String] = - Map(("Event Log Location" -> logDir)) + Map("Event Log Location" -> resolvedLogDir.toString) /** * Builds the application list based on the current contents of the log directory. @@ -114,14 +115,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis lastLogCheckTimeMs = getMonotonicTimeMs() logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) try { - val logStatus = fs.listStatus(new Path(logDir)) + val logStatus = fs.listStatus(new Path(resolvedLogDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() - val logInfos = logDirs.filter { - dir => fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE)) + val logInfos = logDirs.filter { dir => + fs.isFile(new Path(dir.getPath, EventLoggingListener.APPLICATION_COMPLETE)) } val currentApps = Map[String, ApplicationHistoryInfo]( - appList.map(app => (app.id -> app)):_*) + appList.map(app => app.id -> app):_*) // For any application that either (i) is not listed or (ii) has changed since the last time // the listing was created (defined by the log dir's modification time), load the app's info. @@ -131,7 +132,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val curr = currentApps.getOrElse(dir.getPath().getName(), null) if (curr == null || curr.lastUpdated < getModificationTime(dir)) { try { - newApps += loadAppInfo(dir, false)._1 + val (app, _) = loadAppInfo(dir, renderUI = false) + newApps += app } catch { case e: Exception => logError(s"Failed to load app info from directory $dir.") } @@ -159,9 +161,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false. 
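The FsHistoryProvider hunks above swap the raw `logDir` string for `resolvedLogDir`, obtained through `Utils.resolveURI`, so that a bare local path and a fully qualified `hdfs://` path are treated uniformly when building the Hadoop `FileSystem` and `Path`s. A rough sketch of that kind of resolution under stated assumptions (this is not Spark's actual `Utils.resolveURI`):

```scala
import java.io.File
import java.net.URI

object ResolveLogDirSketch {
  // Assumed behavior: strings with an explicit scheme pass through unchanged,
  // while bare paths become absolute file: URIs. The real helper handles more cases.
  def resolveURI(path: String): URI = {
    val uri = new URI(path)
    if (uri.getScheme != null) uri else new File(path).getAbsoluteFile.toURI
  }

  def main(args: Array[String]): Unit = {
    println(resolveURI("/tmp/spark-events"))            // file:/tmp/spark-events
    println(resolveURI("hdfs://namenode:8020/history")) // unchanged
  }
}
```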
*/ private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = { - val elogInfo = EventLoggingListener.parseLoggingInfo(logDir.getPath(), fs) val path = logDir.getPath val appId = path.getName + val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs) val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec) val appListener = new ApplicationEventListener replayBus.addListener(appListener) @@ -187,7 +189,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (ui != null) { val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setUIAcls(uiAclsEnabled) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls) ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) } (appInfo, ui) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d7a3e3f120e67..c4ef8b63b0071 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -45,7 +45,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
-            { providerConfig.map(e => <li><strong>{e._1}:</strong> {e._2}</li>) }
+            {providerConfig.map { case (k, v) => <li><strong>{k}:</strong> {v}</li> }}
{ if (allApps.size > 0) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index cacb9da8c947b..d1a64c1912cb8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -25,9 +25,9 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.ui.{WebUI, SparkUI, UIUtils} +import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.SignalLogger /** * A web server that renders SparkUIs of completed applications. @@ -177,7 +177,7 @@ object HistoryServer extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) initSecurity() - val args = new HistoryServerArguments(conf, argStrings) + new HistoryServerArguments(conf, argStrings) val securityManager = new SecurityManager(conf) val providerName = conf.getOption("spark.history.provider") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index be9361b754fc3..25fc76c23e0fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.history import org.apache.spark.SparkConf -import org.apache.spark.util.Utils /** * Command-line parser for the master. 
@@ -32,6 +31,7 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] args match { case ("--dir" | "-d") :: value :: tail => logDir = value + conf.set("spark.history.fs.logDirectory", value) parse(tail) case ("--help" | "-h") :: tail => @@ -42,9 +42,6 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] case _ => printUsageAndExit(1) } - if (logDir != null) { - conf.set("spark.history.fs.logDirectory", logDir) - } } private def printUsageAndExit(exitCode: Int) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 21f8667819c44..a70ecdb375373 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -154,6 +154,8 @@ private[spark] class Master( } override def postStop() { + masterMetricsSystem.report() + applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { recoveryCompletionTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index a87781fb93850..4b0dbbe543d3f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -38,8 +38,8 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } - if (conf.contains("master.ui.port")) { - webUiPort = conf.get("master.ui.port").toInt + if (conf.contains("spark.master.ui.port")) { + webUiPort = conf.get("spark.master.ui.port").toInt } parse(args.toList) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 16aa0493370dd..d86ec1e03e45c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.AkkaUtils */ private[spark] class MasterWebUI(val master: Master, requestedPort: Int) - extends WebUI(master.securityMgr, requestedPort, master.conf) with Logging { + extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { val masterActorRef = master.self val timeout = AkkaUtils.askTimeout(master.conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 4af5bc3afad6c..687e492a0d6fc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -47,7 +47,6 @@ object CommandUtils extends Logging { */ def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") - val extraOpts = command.extraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq()) // Exists for backwards compatibility with older Spark versions val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString) @@ -62,7 +61,7 @@ object CommandUtils extends Logging { val joined = 
command.libraryPathEntries.mkString(File.pathSeparator) Seq(s"-Djava.library.path=$joined") } else { - Seq() + Seq() } val permGenOpt = Seq("-XX:MaxPermSize=128m") @@ -71,11 +70,11 @@ object CommandUtils extends Logging { val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" val classPath = Utils.executeAndGetOutput( Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment=command.environment) + extraEnvironment = command.environment) val userClassPath = command.classPathEntries ++ Seq(classPath) Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ libraryOpts ++ extraOpts ++ workerLocalOpts ++ memoryOpts + permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 662d37871e7a6..5caaf6bea3575 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -36,6 +36,7 @@ import org.apache.spark.deploy.master.DriverState.DriverState /** * Manages the execution of one driver, including automatically restarting the driver on failure. + * This is currently only used in standalone cluster deploy mode. */ private[spark] class DriverRunner( val driverId: String, @@ -81,7 +82,7 @@ private[spark] class DriverRunner( driverDesc.command.environment, classPath, driverDesc.command.libraryPathEntries, - driverDesc.command.extraJavaOptions) + driverDesc.command.javaOpts) val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem, sparkHome.getAbsolutePath) launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 467317dd9b44c..7be89f9aff0f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.logging.FileAppender /** * Manages the execution of one executor process. + * This is currently only used in standalone mode. */ private[spark] class ExecutorRunner( val appId: String, @@ -72,7 +73,7 @@ private[spark] class ExecutorRunner( } /** - * kill executor process, wait for exit and notify worker to update resource status + * Kill executor process, wait for exit and notify worker to update resource status. 
* * @param message the exception message which caused the executor's death */ @@ -114,10 +115,13 @@ private[spark] class ExecutorRunner( } def getCommandSeq = { - val command = Command(appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), appDesc.command.environment, - appDesc.command.classPathEntries, appDesc.command.libraryPathEntries, - appDesc.command.extraJavaOptions) + val command = Command( + appDesc.command.mainClass, + appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), + appDesc.command.environment, + appDesc.command.classPathEntries, + appDesc.command.libraryPathEntries, + appDesc.command.javaOpts) CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ce425443051b0..458d9947bd873 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -71,7 +71,7 @@ private[spark] class Worker( // TTL for app folders/data; after TTL expires it will be cleaned up val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) - + val testing: Boolean = sys.props.contains("spark.testing") val masterLock: Object = new Object() var master: ActorSelection = null var masterAddress: Address = null @@ -81,7 +81,13 @@ private[spark] class Worker( @volatile var registered = false @volatile var connected = false val workerId = generateWorkerId() - val sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) + val sparkHome = + if (testing) { + assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") + new File(sys.props("spark.test.home")) + } else { + new File(sys.env.get("SPARK_HOME").getOrElse(".")) + } var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] @@ -233,9 +239,7 @@ private[spark] class Worker( try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, - appDesc.sparkHome.map(userSparkHome => new File(userSparkHome)).getOrElse(sparkHome), - workDir, akkaUrl, conf, ExecutorState.RUNNING) + self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -357,6 +361,7 @@ private[spark] class Worker( } override def postStop() { + metricsSystem.report() registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index b389cb546de6c..ecb358c399819 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.worker.ui -import java.io.File import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -25,7 +24,7 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils import org.apache.spark.Logging -import org.apache.spark.util.logging.{FileAppender, RollingFileAppender} +import 
org.apache.spark.util.logging.RollingFileAppender private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker @@ -64,11 +63,11 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val (logDir, params) = (appId, executorId, driverId) match { + val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e") + (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e", s"$a/$e") case (None, None, Some(d)) => - (s"${workDir.getPath}/$d/", s"driverId=$d") + (s"${workDir.getPath}/$d/", s"driverId=$d", d) case _ => throw new Exception("Request must specify either application or driver identifiers") } @@ -120,7 +119,7 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w
- UIUtils.basicSparkPage(content, logType + " log page for " + appId.getOrElse("unknown app")) + UIUtils.basicSparkPage(content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 0ad2edba2227f..47fbda600bea7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker +import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.AkkaUtils @@ -34,7 +35,7 @@ class WorkerWebUI( val worker: Worker, val workDir: File, port: Option[Int] = None) - extends WebUI(worker.securityMgr, WorkerWebUI.getUIPort(port, worker.conf), worker.conf) + extends WebUI(worker.securityMgr, getUIPort(port, worker.conf), worker.conf, name = "WorkerUI") with Logging { val timeout = AkkaUtils.askTimeout(worker.conf) @@ -58,6 +59,6 @@ private[spark] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { - requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT)) + requestedPort.getOrElse(conf.getInt("spark.worker.ui.port", WorkerWebUI.DEFAULT_PORT)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b455c9fcf4bd6..1f46a0f176490 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -88,6 +88,7 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") + executor.stop() context.stop(self) context.system.shutdown() } @@ -98,8 +99,13 @@ private[spark] class CoarseGrainedExecutorBackend( } private[spark] object CoarseGrainedExecutorBackend extends Logging { - def run(driverUrl: String, executorId: String, hostname: String, cores: Int, - workerUrl: Option[String]) { + + private def run( + driverUrl: String, + executorId: String, + hostname: String, + cores: Int, + workerUrl: Option[String]) { SignalLogger.register(log) @@ -109,8 +115,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf + val port = executorConf.getInt("spark.executor.port", 0) val (fetcher, _) = AkkaUtils.createActorSystem( - "driverPropsFetcher", hostname, 0, executorConf, new SecurityManager(executorConf)) + "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) val driver = fetcher.actorSelection(driverUrl) val timeout = AkkaUtils.askTimeout(executorConf) val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) @@ -120,7 +127,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. 
val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, 0, driverConf, new SecurityManager(driverConf)) + "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3b69bc4ca4142..c2b9c660ddaec 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import java.util.concurrent._ import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark._ import org.apache.spark.scheduler._ @@ -48,6 +48,8 @@ private[spark] class Executor( private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + @volatile private var isStopped = false + // No ip or host:port - just hostname Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -107,6 +109,8 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + startDriverHeartbeater() + def launchTask( context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { val tr = new TaskRunner(context, taskId, taskName, serializedTask) @@ -121,6 +125,12 @@ private[spark] class Executor( } } + def stop() { + env.metricsSystem.report() + isStopped = true + threadPool.shutdown() + } + /** Get the Yarn approved local directories. 
*/ private def getYarnLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -137,11 +147,12 @@ private[spark] class Executor( } class TaskRunner( - execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) + execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false - @volatile private var task: Task[Any] = _ + @volatile var task: Task[Any] = _ + @volatile var attemptedTask: Option[Task[Any]] = None def kill(interruptThread: Boolean) { logInfo(s"Executor is trying to kill $taskName (TID $taskId)") @@ -158,7 +169,6 @@ private[spark] class Executor( val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum val startGCTime = gcTime @@ -200,7 +210,6 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.hostname = Utils.localHostName() m.executorDeserializeTime = taskStart - startTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime @@ -267,10 +276,7 @@ private[spark] class Executor( } } finally { // Release memory used by this thread for shuffles - val shuffleMemoryMap = env.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap.remove(Thread.currentThread().getId) - } + env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) @@ -350,4 +356,42 @@ private[spark] class Executor( } } } + + def startDriverHeartbeater() { + val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val timeout = AkkaUtils.lookupTimeout(conf) + val retryAttempts = AkkaUtils.numRetries(conf) + val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + + val t = new Thread() { + override def run() { + // Sleep a random interval so the heartbeats don't end up in sync + Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) + + while (!isStopped) { + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + for (taskRunner <- runningTasks.values()) { + if (!taskRunner.attemptedTask.isEmpty) { + Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + tasksMetrics += ((taskRunner.taskId, metrics)) + } + } + } + + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + Thread.sleep(interval) + } + } + } + t.setDaemon(true) + t.setName("Driver Heartbeater") + t.start() + } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 21fe643b8d71f..56cd8723a3a22 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -23,6 +23,14 @@ import 
org.apache.spark.storage.{BlockId, BlockStatus} /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. + * + * This class is used to house metrics both for in-progress and completed tasks. In executors, + * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread + * reads it to send in-progress metrics, and the task thread reads it to send metrics along with + * the completed task. + * + * So, when adding new fields, take into consideration that the whole object can be serialized for + * shipping off at any time to consumers of the SparkListener interface. */ @DeveloperApi class TaskMetrics extends Serializable { @@ -143,7 +151,7 @@ class ShuffleReadMetrics extends Serializable { /** * Absolute time when this task finished reading shuffle data */ - var shuffleFinishTime: Long = _ + var shuffleFinishTime: Long = -1 /** * Number of blocks fetched in this shuffle by this task (remote or local) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 651511da1b7fe..6ef817d0e587e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -91,6 +91,10 @@ private[spark] class MetricsSystem private (val instance: String, sinks.foreach(_.stop) } + def report(): Unit = { + sinks.foreach(_.report()) + } + def registerSource(source: Source) { sources += source try { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 05852f1f98993..81b9056b40fb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -57,5 +57,9 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 542dce65366b2..9d5f2ae9328ad 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -66,5 +66,9 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index aeb4ad44a0647..d7b5f5c40efae 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -81,4 +81,8 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index ed27234b4e760..2588fe2c9edb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -35,4 +35,6 @@ private[spark] class JmxSink(val property: Properties, val registry: MetricRegis reporter.stop() } + 
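The expanded TaskMetrics comment above and the `startDriverHeartbeater` thread added to Executor a few hunks earlier describe the same contract: the task thread mutates the metrics object while a daemon thread periodically snapshots it and ships the result to the driver. A stripped-down sketch of that shape, with every name invented for illustration:

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

object HeartbeaterSketch {
  // Invented stand-in for TaskMetrics: written by the task thread, read by the heartbeater.
  class FakeMetrics { @volatile var recordsRead: Long = 0L }

  val runningTasks = new ConcurrentHashMap[Long, FakeMetrics]()
  @volatile var stopped = false

  def startHeartbeater(intervalMs: Long)(send: Seq[(Long, Long)] => Unit): Thread = {
    val t = new Thread("driver-heartbeater-sketch") {
      override def run(): Unit = {
        while (!stopped) {
          // Snapshot in-progress metrics without blocking the task threads.
          val snapshot = runningTasks.asScala.toSeq.map { case (id, m) => (id, m.recordsRead) }
          send(snapshot)
          Thread.sleep(intervalMs)
        }
      }
    }
    t.setDaemon(true)
    t.start()
    t
  }

  def main(args: Array[String]): Unit = {
    val metrics = new FakeMetrics
    runningTasks.put(1L, metrics)
    startHeartbeater(100)(batch => println(s"heartbeat: $batch"))
    (1 to 5).foreach { _ => metrics.recordsRead += 100; Thread.sleep(100) }
    stopped = true
  }
}
```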
override def report() { } + } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 571539ba5e467..2f65bc8b46609 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -57,4 +57,6 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr override def start() { } override def stop() { } + + override def report() { } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 6f2b5a06027ea..0d83d8c425ca4 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -20,4 +20,5 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { def start: Unit def stop: Unit + def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 566e8a4aaa1d2..4c00225280cce 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -38,8 +38,12 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.{SystemClock, Utils} -private[spark] class ConnectionManager(port: Int, conf: SparkConf, - securityManager: SecurityManager) extends Logging { +private[spark] class ConnectionManager( + port: Int, + conf: SparkConf, + securityManager: SecurityManager, + name: String = "Connection manager") + extends Logging { class MessageStatus( val message: Message, @@ -105,7 +109,11 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, serverChannel.socket.setReuseAddress(true) serverChannel.socket.setReceiveBufferSize(256 * 1024) - serverChannel.socket.bind(new InetSocketAddress(port)) + private def startService(port: Int): (ServerSocketChannel, Int) = { + serverChannel.socket.bind(new InetSocketAddress(port)) + (serverChannel, serverChannel.socket.getLocalPort) + } + Utils.startServiceOnPort[ServerSocketChannel](port, startService, name) serverChannel.register(selector, SelectionKey.OP_ACCEPT) val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 6388ef82cc5db..fabb882cdd4b3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -17,10 +17,11 @@ package org.apache.spark.rdd +import scala.language.existentials + import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.language.existentials import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled + context.taskMetrics.memoryBytesSpilled += 
map.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index e7221e3032c11..11ebafbf6d457 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -49,8 +49,8 @@ private[spark] case class CoalescedRDDPartition( } /** - * Computes how many of the parents partitions have getPreferredLocation - * as one of their preferredLocations + * Computes the fraction of the parents' partitions containing preferredLocation within + * their getPreferredLocs. * @return locality of this coalesced partition between 0 and 1 */ def localFraction: Double = { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e521612ffc27c..8d92ea01d9a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -20,7 +20,9 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date import java.io.EOFException + import scala.collection.immutable.Map +import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred.FileSplit @@ -39,6 +41,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.NextIterator /** @@ -232,6 +235,14 @@ class HadoopRDD[K, V]( new InterruptibleIterator[(K, V)](context, iter) } + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new HadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + override def getPreferredLocations(split: Partition): Seq[String] = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopPartition] @@ -272,4 +283,25 @@ private[spark] object HadoopRDD { conf.setInt("mapred.task.partition", splitId) conf.set("mapred.job.id", jobID.toString) } + + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. 
+ */ + private[spark] class HadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = { + val partition = split.asInstanceOf[HadoopPartition] + val inputSplit = partition.inputSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index a76a070b5b863..8947e66f4577c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -96,17 +96,23 @@ class JdbcRDD[T: ClassTag]( override def close() { try { - if (null != rs && ! rs.isClosed()) rs.close() + if (null != rs && ! rs.isClosed()) { + rs.close() + } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) stmt.close() + if (null != stmt && ! stmt.isClosed()) { + stmt.close() + } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! stmt.isClosed()) conn.close() + if (null != conn && ! conn.isClosed()) { + conn.close() + } logInfo("closed connection") } catch { case e: Exception => logWarning("Exception closing connection", e) @@ -120,3 +126,4 @@ object JdbcRDD { Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) } } + diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f2b3a64bf1345..7dfec9a18ec67 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -32,6 +34,7 @@ import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD private[spark] class NewHadoopPartition( rddId: Int, @@ -157,6 +160,14 @@ class NewHadoopRDD[K, V]( new InterruptibleIterator(context, iter) } + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. 
*/ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + override def getPreferredLocations(split: Partition): Seq[String] = { val theSplit = split.asInstanceOf[NewHadoopPartition] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") @@ -165,6 +176,29 @@ class NewHadoopRDD[K, V]( def getConf: Configuration = confBroadcast.value.value } +private[spark] object NewHadoopRDD { + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. + */ + private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = { + val partition = split.asInstanceOf[NewHadoopPartition] + val inputSplit = partition.serializableHadoopSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } +} + private[spark] class WholeTextFileRDD( sc : SparkContext, inputFormatClass: Class[_ <: WholeTextFileInputFormat], diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index afd7075f686b9..e98bad2026e32 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark.{Logging, RangePartitioner} +import org.apache.spark.annotation.DeveloperApi /** * Extra functions available on RDDs of (key, value) pairs where the key is sortable through @@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner} */ class OrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag, - P <: Product2[K, V] : ClassTag]( + P <: Product2[K, V] : ClassTag] @DeveloperApi() ( self: RDD[P]) - extends Logging with Serializable { - + extends Logging with Serializable +{ private val ordering = implicitly[Ordering[K]] /** @@ -55,15 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in * order of the keys). */ - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { + // TODO: this currently doesn't work on P other than Tuple2! 
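Stepping back to the `mapPartitionsWithInputSplit` methods added to HadoopRDD and NewHadoopRDD above: they hand user code the InputSplit behind each partition. A hedged usage sketch follows; it assumes a local SparkContext and a `README.md` input file, and the cast to `FileSplit` is only valid for file-based input formats.

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.HadoopRDD

object InputSplitUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("split-sketch").setMaster("local[2]"))
    // sc.hadoopFile builds a HadoopRDD under the hood; the cast exposes the new method.
    val rdd = sc.hadoopFile[LongWritable, Text, TextInputFormat]("README.md")
      .asInstanceOf[HadoopRDD[LongWritable, Text]]
    // Tag every record with the path of the split it came from.
    val tagged = rdd.mapPartitionsWithInputSplit { (split, iter) =>
      val path = split.asInstanceOf[FileSplit].getPath.toString
      iter.map { case (_, line) => (path, line.toString) }
    }
    tagged.take(3).foreach(println)
    sc.stop()
  }
}
```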
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + : RDD[(K, V)] = + { val part = new RangePartitioner(numPartitions, self, ascending) - new ShuffledRDD[K, V, V, P](self, part) - .setKeyOrdering(ordering) - .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING) + new ShuffledRDD[K, V, V](self, part) + .setKeyOrdering(if (ascending) ordering else ordering.reverse) } } - -private[spark] object SortOrder extends Enumeration { - type SortOrder = Value - val ASCENDING, DESCENDING = Value -} diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c04d162a39616..93af50c0a9cd1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,12 +19,10 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.Date -import java.util.{HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap} +import scala.collection.{Map, mutable} import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -34,19 +32,19 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob, +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.spark._ -import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.SparkHadoopWriter import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.random.StratifiedSamplingUtils /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -92,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) } else { - new ShuffledRDD[K, V, C, (K, C)](self, partitioner) + new ShuffledRDD[K, V, C](self, partitioner) .setSerializer(serializer) .setAggregator(aggregator) .setMapSideCombine(mapSideCombine) @@ -195,6 +193,41 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) foldByKey(zeroValue, defaultPartitioner(self))(func) } + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. 
+ * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use + * additional passes over the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling + * without replacement, we need one additional pass over the RDD to guarantee sample size; + * when sampling with replacement, we need two additional passes. + * + * @param withReplacement whether to sample with or without replacement + * @param fractions map of specific keys to sampling rates + * @param seed seed for the random number generator + * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key + * @return RDD containing the sampled subset + */ + def sampleByKey(withReplacement: Boolean, + fractions: Map[K, Double], + exact: Boolean = false, + seed: Long = Utils.random.nextLong): RDD[(K, V)]= { + + require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") + + val samplingFunc = if (withReplacement) { + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed) + } else { + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed) + } + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a @@ -392,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (self.partitioner == Some(partitioner)) { self } else { - new ShuffledRDD[K, V, V, (K, V)](self, partitioner) + new ShuffledRDD[K, V, V](self, partitioner) } } @@ -531,6 +564,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. + * + * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only + * one value per key is preserved in the map returned) */ def collectAsMap(): Map[K, V] = { val data = self.collect() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a6abc49c5359e..e1c49e35abecd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -331,7 +332,7 @@ abstract class RDD[T: ClassTag]( val distributePartition = (index: Int, items: Iterator[T]) => { var position = (new Random(index)).nextInt(numPartitions) items.map { t => - // Note that the hash code of the key will just be the key itself. 
The HashPartitioner + // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. position = position + 1 (position, t) @@ -340,7 +341,7 @@ abstract class RDD[T: ClassTag]( // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition), + new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition), new HashPartitioner(numPartitions)), numPartitions).values } else { @@ -351,8 +352,8 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, - fraction: Double, + def sample(withReplacement: Boolean, + fraction: Double, seed: Long = Utils.random.nextLong): RDD[T] = { require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { @@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } + def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } + def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) // ======================================================================= // Other internal methods and fields @@ -1239,6 +1236,23 @@ abstract class RDD[T: ClassTag]( /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context = sc + /** + * Private API for changing an RDD's ClassTag. + * Used for internal Java <-> Scala API compatibility. + */ + private[spark] def retag(cls: Class[T]): RDD[T] = { + val classTag: ClassTag[T] = ClassTag.apply(cls) + this.retag(classTag) + } + + /** + * Private API for changing an RDD's ClassTag. + * Used for internal Java <-> Scala API compatibility. 
+ */ + private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = { + this.mapPartitions(identity, preservesPartitioning = true)(classTag) + } + // Avoid handling doCheckpoint multiple times to prevent excessive recursion @transient private var doCheckpointCalled = false diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index c3b2a33fb54d0..f67e5f1857979 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) cpRDD = Some(newRDD) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() } logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } @@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} +// Used for synchronization +private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index da4a8c3dc22b1..d9fe6847254fa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -38,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * @tparam V the value class. * @tparam C the combiner class. */ +// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( +class ShuffledRDD[K, V, C]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) - extends RDD[P](prev.context, Nil) { + extends RDD[(K, C)](prev.context, Nil) { private var serializer: Option[Serializer] = None @@ -52,41 +52,32 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( private var mapSideCombine: Boolean = false - private var sortOrder: Option[SortOrder] = None - /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ - def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { + def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = { this.serializer = Option(serializer) this } /** Set key ordering for RDD's shuffle. */ - def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = { + def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = { this.keyOrdering = Option(keyOrdering) this } /** Set aggregator for RDD's shuffle. */ - def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = { + def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = { this.aggregator = Option(aggregator) this } /** Set mapSideCombine flag for RDD's shuffle. 
*/ - def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = { + def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = { this.mapSideCombine = mapSideCombine this } - /** Set sort order for RDD's sorting. */ - def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = { - this.sortOrder = Option(sortOrder) - this - } - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializer, - keyOrdering, aggregator, mapSideCombine, sortOrder)) + List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } override val partitioner = Some(part) @@ -95,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } - override def compute(split: Partition, context: TaskContext): Iterator[P] = { + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) .read() - .asInstanceOf[Iterator[P]] + .asInstanceOf[Iterator[(K, C)]] } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 21c6e07d69f90..197167ecad0bd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -25,21 +25,32 @@ import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} import org.apache.spark.annotation.DeveloperApi -private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int) +/** + * Partition for UnionRDD. 
+ * + * @param idx index of the partition + * @param rdd the parent RDD this partition refers to + * @param parentRddIndex index of the parent RDD this partition refers to + * @param parentRddPartitionIndex index of the partition within the parent RDD + * this partition refers to + */ +private[spark] class UnionPartition[T: ClassTag]( + idx: Int, + @transient rdd: RDD[T], + val parentRddIndex: Int, + @transient parentRddPartitionIndex: Int) extends Partition { - var split: Partition = rdd.partitions(splitIndex) - - def iterator(context: TaskContext) = rdd.iterator(split, context) + var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(split) + def preferredLocations() = rdd.preferredLocations(parentPartition) override val index: Int = idx @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - split = rdd.partitions(splitIndex) + parentPartition = rdd.partitions(parentRddPartitionIndex) oos.defaultWriteObject() } } @@ -47,14 +58,14 @@ private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitInd @DeveloperApi class UnionRDD[T: ClassTag]( sc: SparkContext, - @transient var rdds: Seq[RDD[T]]) + var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { val array = new Array[Partition](rdds.map(_.partitions.size).sum) var pos = 0 - for (rdd <- rdds; split <- rdd.partitions) { - array(pos) = new UnionPartition(pos, rdd, split.index) + for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { + array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) pos += 1 } array @@ -70,9 +81,17 @@ class UnionRDD[T: ClassTag]( deps } - override def compute(s: Partition, context: TaskContext): Iterator[T] = - s.asInstanceOf[UnionPartition[T]].iterator(context) + override def compute(s: Partition, context: TaskContext): Iterator[T] = { + val part = s.asInstanceOf[UnionPartition[T]] + val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]] + parentRdd.iterator(part.parentPartition, context) + } override def getPreferredLocations(s: Partition): Seq[String] = s.asInstanceOf[UnionPartition[T]].preferredLocations() + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala new file mode 100644 index 0000000000000..fa83372bb4d11 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + */ +@DeveloperApi +class AccumulableInfo ( + val id: Long, + val name: String, + val update: Option[String], // represents a partial update within a task + val value: String) { + + override def equals(other: Any): Boolean = other match { + case acc: AccumulableInfo => + this.id == acc.id && this.name == acc.name && + this.update == acc.update && this.value == acc.value + case _ => false + } +} + +object AccumulableInfo { + def apply(id: Long, name: String, update: Option[String], value: String) = + new AccumulableInfo(id, name, update, value) + + def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index cd5d44ad4a7e6..162158babc35b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -29,7 +29,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var startTime = -1L var endTime = -1L var viewAcls = "" - var enableViewAcls = false + var adminAcls = "" def applicationStarted = startTime != -1 @@ -55,7 +55,7 @@ private[spark] class ApplicationEventListener extends SparkListener { val environmentDetails = environmentUpdate.environmentDetails val allProperties = environmentDetails("Spark Properties").toMap viewAcls = allProperties.getOrElse("spark.ui.view.acls", "") - enableViewAcls = allProperties.getOrElse("spark.ui.acls.enable", "false").toBoolean + adminAcls = allProperties.getOrElse("spark.admin.acls", "") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dc6142ab79d03..430e45ada5808 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,11 +17,11 @@ package org.apache.spark.scheduler -import java.io.{NotSerializableException, PrintWriter, StringWriter} +import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -29,17 +29,18 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import akka.actor._ -import akka.actor.OneForOneStrategy import akka.actor.SupervisorStrategy.Stop import akka.pattern.ask import akka.util.Timeout import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} +import org.apache.spark.storage._ import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** * The high-level scheduling layer that implements stage-oriented 
scheduling. It computes a DAG of @@ -114,6 +115,10 @@ class DAGScheduler( private val dagSchedulerActorSupervisor = env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) + // A closure serializer that we reuse. + // This is only safe because DAGScheduler runs in a single thread. + private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + private[scheduler] var eventProcessActor: ActorRef = _ private def initializeEventProcessActor() { @@ -149,6 +154,23 @@ class DAGScheduler( eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics) + blockManagerId: BlockManagerId): Boolean = { + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + implicit val timeout = Timeout(600 seconds) + + Await.result( + blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), + timeout.duration).asInstanceOf[Boolean] + } + // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { eventProcessActor ! ExecutorLost(execId) @@ -189,11 +211,15 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => + // We are going to register ancestor shuffle dependencies + registerShuffleDependencies(shuffleDep, jobId) + // Then register current shuffleDep val stage = newOrUsedStage( shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, shuffleDep.rdd.creationSite) shuffleToMapStage(shuffleDep.shuffleId) = stage + stage } } @@ -258,6 +284,9 @@ class DAGScheduler( private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(r: RDD[_]) { if (!visited(r)) { visited += r @@ -268,18 +297,69 @@ class DAGScheduler( case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => - visit(dep.rdd) + waitingForVisit.push(dep.rdd) } } } } - visit(rdd) + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } parents.toList } + // Find ancestor missing shuffle dependencies and register into shuffleToMapStage + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) + while (!parentsWithNoMapStage.isEmpty) { + val currentShufDep = parentsWithNoMapStage.pop() + val stage = + newOrUsedStage( + currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, + currentShufDep.rdd.creationSite) + shuffleToMapStage(currentShufDep.shuffleId) = stage + } + } + + // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { + val parents = new Stack[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by 
recursively visiting + val waitingForVisit = new Stack[RDD[_]] + def visit(r: RDD[_]) { + if (!visited(r)) { + visited += r + for (dep <- r.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + if (!shuffleToMapStage.contains(shufDep.shuffleId)) { + parents.push(shufDep) + } + + waitingForVisit.push(shufDep.rdd) + case _ => + waitingForVisit.push(dep.rdd) + } + } + } + } + + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } + parents + } + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd @@ -292,13 +372,16 @@ class DAGScheduler( missing += mapStage } case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } missing.toList } @@ -361,9 +444,6 @@ class DAGScheduler( // data structures based on StageId stageIdToStage -= stageId - ShuffleMapTask.removeStage(stageId) - ResultTask.removeStage(stageId) - logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } @@ -691,49 +771,83 @@ class DAGScheduler( } } - /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() + + val properties = if (jobIdToActiveJob.contains(jobId)) { + jobIdToActiveJob(stage.jobId).properties + } else { + // this stage will be assigned to "default" pool + null + } + + runningStages += stage + // SparkListenerStageSubmitted should be posted before testing whether tasks are + // serializable. If tasks are not serializable, a SparkListenerStageCompleted event + // will be posted, which should always come after a corresponding SparkListenerStageSubmitted + // event. + listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + + // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. + // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast + // the serialized copy of the RDD and for each task we will deserialize it, which means each + // task gets a different copy of the RDD. This provides stronger isolation between tasks that + // might modify state of objects referenced in their closures. This is necessary in Hadoop + // where the JobConf/Configuration object is not thread-safe. + var taskBinary: Broadcast[Array[Byte]] = null + try { + // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). + // For ResultTask, serialize and broadcast (rdd, func). + val taskBinaryBytes: Array[Byte] = + if (stage.isShuffleMap) { + closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() + } else { + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) + } catch { + // In the case of a failure during serialization, abort the stage. 
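// Editorial sketch, not part of the patch: the explicit-Stack traversal pattern adopted by
// the getParentStages / getAncestorShuffleDependencies / getMissingParentStages rewrites
// above. Recursing once per dependency can overflow the JVM stack on very long lineage
// chains; an explicit Stack keeps the call depth constant. `neighbours` is a stand-in for
// rdd.dependencies, not a Spark API.
import scala.collection.mutable.{HashSet, Stack}

def reachable[T](start: T)(neighbours: T => Seq[T]): Set[T] = {
  val visited = new HashSet[T]
  val waitingForVisit = new Stack[T]
  waitingForVisit.push(start)
  while (waitingForVisit.nonEmpty) {
    val current = waitingForVisit.pop()
    if (visited.add(current)) {
      // push instead of recursing, so graph depth never maps onto call-stack depth
      neighbours(current).foreach(waitingForVisit.push)
    }
  }
  visited.toSet
}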
+ case e: NotSerializableException => + abortStage(stage, "Task not serializable: " + e.toString) + runningStages -= stage + return + case NonFatal(e) => + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + if (stage.isShuffleMap) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + val part = stage.rdd.partitions(p) + tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs) } } else { // This is a final stage; figure out its job's missing partitions val job = stage.resultOfJob.get for (id <- 0 until job.numPartitions if !job.finished(id)) { - val partition = job.partitions(id) - val locs = getPreferredLocs(stage.rdd, partition) - tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ResultTask(stage.id, taskBinary, part, locs, id) } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } - if (tasks.size > 0) { - runningStages += stage - // SparkListenerStageSubmitted should be posted before testing whether tasks are - // serializable. If tasks are not serializable, a SparkListenerStageCompleted event - // will be posted, which should always come after a corresponding SparkListenerStageSubmitted - // event. - listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) - // Preemptively serialize a task to make sure it can be serialized. We are catching this // exception here because it would be fairly hard to catch the non-serializable exception // down the road, where we have several different implementations for local scheduler and // cluster schedulers. + // + // We've already serialized RDDs and closures in taskBinary, but here we check for all other + // objects such as Partition. try { - SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) + closureSerializer.serialize(tasks.head) } catch { case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) @@ -752,6 +866,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.info.submissionTime = Some(clock.getTime()) } else { + // Because we posted SparkListenerStageSubmitted earlier, we should post + // SparkListenerStageCompleted here in case there are no tasks to run. + listenerBus.post(SparkListenerStageCompleted(stage.info)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) runningStages -= stage @@ -766,8 +883,14 @@ class DAGScheduler( val task = event.task val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + + // The success case is dealt with separately below, since we need to compute accumulator + // updates before posting. + if (event.reason != Success) { + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) + } + if (!stageIdToStage.contains(task.stageId)) { // Skip all the actions if the stage has been cancelled. 
return @@ -787,9 +910,28 @@ class DAGScheduler( event.reason match { case Success => if (event.accumUpdates != null) { - // TODO: fail the stage if the accumulator update fails... - Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted + try { + Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => + val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name.get + val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) + val stringValue = Accumulators.stringifyValue(acc.value) + stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue) + event.taskInfo.accumulables += + AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + } + } + } catch { + // If we see an exception during accumulator update, just log the error and move on. + case e: Exception => + logError(s"Failed to update accumulators for $task", e) + } } + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => @@ -1063,6 +1205,9 @@ class DAGScheduler( } val visitedRdds = new HashSet[RDD[_]] val visitedStages = new HashSet[Stage] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visitedRdds(rdd)) { visitedRdds += rdd @@ -1072,15 +1217,18 @@ class DAGScheduler( val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage - visit(mapStage.rdd) + waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } visitedRdds.contains(target.rdd) } @@ -1092,6 +1240,22 @@ class DAGScheduler( */ private[spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + getPreferredLocsInternal(rdd, partition, new HashSet) + } + + /** Recursive implementation for getPreferredLocs. */ + private def getPreferredLocsInternal( + rdd: RDD[_], + partition: Int, + visited: HashSet[(RDD[_],Int)]) + : Seq[TaskLocation] = + { + // If the partition has already been visited, no need to re-visit. + // This avoids exponential path exploration. SPARK-695 + if (!visited.add((rdd,partition))) { + // Nil has already been returned for previously visited partitions. 
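// Editorial sketch, not part of the patch: the memoisation that getPreferredLocsInternal
// introduces here. Marking each node as visited before descending turns an exponential
// walk over diamond-shaped narrow-dependency graphs into a linear one (SPARK-695).
// Node is a stand-in type, not a Spark class.
import scala.collection.mutable.HashSet

final case class Node(id: Int, preferred: Seq[String], parents: Seq[Node])

def preferredLocs(node: Node, visited: HashSet[Int] = new HashSet[Int]): Seq[String] = {
  if (!visited.add(node.id)) return Nil               // already explored via another path
  if (node.preferred.nonEmpty) return node.preferred  // found a location; stop here
  for (parent <- node.parents) {
    val locs = preferredLocs(parent, visited)
    if (locs.nonEmpty) return locs
  }
  Nil
}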
+ return Nil + } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (!cached.isEmpty) { @@ -1108,7 +1272,7 @@ class DAGScheduler( rdd.dependencies.foreach { case n: NarrowDependency[_] => for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) + val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index ae6ca9f4e7bf5..406147f167bf3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -29,7 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{FileLogger, JsonProtocol} +import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} /** * A SparkListener that logs events to persistent storage. @@ -55,7 +55,7 @@ private[spark] class EventLoggingListener( private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") private val name = appName.replaceAll("[ :/]", "-").toLowerCase + "-" + System.currentTimeMillis - val logDir = logBaseDir + "/" + name + val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS)) @@ -215,7 +215,7 @@ private[spark] object EventLoggingListener extends Logging { } catch { case e: Exception => logError("Exception in parsing logging info from directory %s".format(logDir), e) - EventLoggingInfo.empty + EventLoggingInfo.empty } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index bbf9f7388b074..d09fd7aa57642 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,134 +17,56 @@ package org.apache.spark.scheduler -import scala.language.existentials +import java.nio.ByteBuffer import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} - -private[spark] object ResultTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. 
- private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = - { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(func) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = - { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - (rdd, func) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD /** * A task that sends back the output to the driver application. * - * See [[org.apache.spark.scheduler.Task]] for more information. + * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param rdd input to func - * @param func a function to apply on a partition of the RDD - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each + * partition of the given RDD. Once deserialized, the type should be + * (RDD[T], (TaskContext, Iterator[T]) => U). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). */ private[spark] class ResultTask[T, U]( stageId: Int, - var rdd: RDD[T], - var func: (TaskContext, Iterator[T]) => U, - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient locs: Seq[TaskLocation], - var outputId: Int) - extends Task[U](stageId, _partitionId) with Externalizable { + val outputId: Int) + extends Task[U](stageId, partition.index) with Serializable { - def this() = this(0, null, null, 0, null, 0) - - var split = if (rdd == null) null else rdd.partitions(partitionId) - - @transient private val preferredLocs: Seq[TaskLocation] = { + @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } override def runTask(context: TaskContext): U = { + // Deserialize the RDD and the func using the broadcast variables. + val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) try { - func(context, rdd.iterator(split, context)) + func(context, rdd.iterator(partition, context)) } finally { context.executeOnCompleteCallbacks() } } + // This is only callable on the driver side. 
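// Editorial sketch, not part of the patch: why runTask above deserializes taskBinary
// instead of shipping a pre-built RDD/closure per task. Every deserialization of the same
// byte array yields a fresh object graph, so tasks cannot observe each other's mutations
// (important for non-thread-safe state such as a Hadoop JobConf). Plain JDK serialization
// stands in for Spark's closureSerializer here.
import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
  val buffer = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(buffer)
  out.writeObject(obj)
  out.close()
  buffer.toByteArray
}

def deserialize[T](bytes: Array[Byte]): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
  try in.readObject().asInstanceOf[T] finally in.close()
}

val shared = serialize(Array(1, 2, 3))        // built and "broadcast" once by the driver
val copyA  = deserialize[Array[Int]](shared)  // task A gets its own copy
val copyB  = deserialize[Array[Int]](shared)  // task B gets its own copy
copyA(0) = 99
assert(copyB(0) == 1)                         // A's mutation is invisible to B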
override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ResultTask.serializeInfo( - stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeInt(outputId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) - rdd = rdd_.asInstanceOf[RDD[T]] - func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partitionId = in.readInt() - outputId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fdaf1de83f051..11255c07469d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,134 +17,55 @@ package org.apache.spark.scheduler -import scala.language.existentials - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import java.nio.ByteBuffer -import scala.collection.mutable.HashMap +import scala.language.existentials import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter -private[spark] object ShuffleMapTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] - (rdd, dep) - } - - // Since both the JarSet and FileSet have the same format this is used for both. 
- def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) - val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - HashMap(set.toSeq: _*) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - /** - * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner - * specified in the ShuffleDependency). - * - * See [[org.apache.spark.scheduler.Task]] for more information. - * +* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner +* specified in the ShuffleDependency). +* +* See [[org.apache.spark.scheduler.Task]] for more information. +* * @param stageId id of the stage this task belongs to - * @param rdd the final RDD in this stage - * @param dep the ShuffleDependency - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * the type should be (RDD[_], ShuffleDependency[_, _, _]). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling */ private[spark] class ShuffleMapTask( stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_, _, _], - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, _partitionId) - with Externalizable - with Logging { + extends Task[MapStatus](stageId, partition.index) with Logging { - protected def this() = this(0, null, null, 0, null) + /** A constructor used only in test suites. This does not require passing in an RDD. */ + def this(partitionId: Int) { + this(0, null, new Partition { override def index = 0 }, null) + } @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partitionId) - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partitionId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } - override def runTask(context: TaskContext): MapStatus = { + // Deserialize the RDD using the broadcast variable. 
+ val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) - writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) + writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) return writer.stop(success = true).get } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 82163eadd56e9..d01d318633877 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -75,6 +75,12 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorMetricsUpdate( + execId: String, + taskMetrics: Seq[(Long, Int, TaskMetrics)]) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) extends SparkListenerEvent @@ -158,6 +164,11 @@ trait SparkListener { * Called when the application ends */ def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { } + + /** + * Called when the driver receives task metrics from an executor in a heartbeat. + */ + def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index ed9fb24bc8ce8..e79ffd7a3587d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -68,6 +68,8 @@ private[spark] trait SparkListenerBus extends Logging { foreachListener(_.onApplicationStart(applicationStart)) case applicationEnd: SparkListenerApplicationEnd => foreachListener(_.onApplicationEnd(applicationEnd)) + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) case SparkListenerShutdown => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 480891550eb60..2a407e47a05bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashMap + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo @@ -37,6 +39,8 @@ class StageInfo( var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None + /** Terminal values of accumulables updated during this stage. 
*/ + val accumulables = HashMap[Long, AccumulableInfo]() def stageFailed(reason: String) { failureReason = Some(reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5871edeb856ad..5c5e421404a21 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -26,6 +26,8 @@ import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils + /** * A unit of execution. We have two kinds of Task's in Spark: @@ -44,6 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + context.taskMetrics.hostname = Utils.localHostName(); taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index ca0595f35143e..6fa1f2c880f7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.ListBuffer + import org.apache.spark.annotation.DeveloperApi /** @@ -41,6 +43,13 @@ class TaskInfo( */ var gettingResultTime: Long = 0 + /** + * Intermediate updates to accumulables during this task. Note that it is valid for the same + * accumulable to be updated multiple times in a single task or for two accumulables with the + * same name but different IDs to exist in a task. + */ + val accumulables = ListBuffer[AccumulableInfo]() + /** * The time when the task has completed successfully (including the time to remotely fetch * results, if necessary). diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala index eb920ab0c0b67..f176d09816f5e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi object TaskLocality extends Enumeration { // Process local is expected to be used ONLY within TaskSetManager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value type TaskLocality = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 819c35257b5a7..1a0b877c8a5e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -18,6 +18,8 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId /** * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. @@ -54,4 +56,12 @@ private[spark] trait TaskScheduler { // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. 
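// Editorial sketch, not part of the patch: one way a SparkListener could consume the
// accumulable bookkeeping added in this change. StageInfo.accumulables carries terminal
// per-stage values and TaskInfo.accumulables the per-task partial updates; only named
// accumulators are recorded by the DAGScheduler change above. Register the listener with
// sc.addSparkListener(...).
import org.apache.spark.scheduler._

class AccumulableLoggingListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    taskEnd.taskInfo.accumulables.foreach { acc =>
      println(s"task ${taskEnd.taskInfo.taskId}: ${acc.name} += ${acc.update.getOrElse("?")}")
    }
  }
  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
    stageCompleted.stageInfo.accumulables.values.foreach { acc =>
      println(s"stage ${stageCompleted.stageInfo.stageId}: ${acc.name} = ${acc.value}")
    }
  }
}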
def defaultParallelism(): Int + + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index be3673c48eda8..6c0d1b2752a81 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -32,6 +32,9 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.util.Utils +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import akka.actor.Props /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. @@ -86,11 +89,11 @@ private[spark] class TaskSchedulerImpl( // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] + protected val executorsByHost = new HashMap[String, HashSet[String]] protected val hostsByRack = new HashMap[String, HashSet[String]] - private val executorIdToHost = new HashMap[String, String] + protected val executorIdToHost = new HashMap[String, String] // Listener object to pass upcalls into var dagScheduler: DAGScheduler = null @@ -246,6 +249,7 @@ private[spark] class TaskSchedulerImpl( // Take each TaskSet in our scheduling order, and then offer it each node in increasing order // of locality levels so that it gets a chance to launch local tasks on all of them. + // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY var launchedTask = false for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { do { @@ -262,7 +266,7 @@ private[spark] class TaskSchedulerImpl( activeExecutorIds += execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK - assert (availableCpus(i) >= 0) + assert(availableCpus(i) >= 0) launchedTask = true } } @@ -320,6 +324,26 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. 
+ */ + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId): Boolean = { + val metricsWithStageIds = taskMetrics.flatMap { + case (id, metrics) => { + taskIdToTaskSetId.get(id) + .flatMap(activeTaskSets.get) + .map(_.stageId) + .map(x => (id, x, metrics)) + } + } + dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) + } + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { taskSetManager.handleTaskGettingResult(tid) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8b5e8cb802a45..20a4bd12f93f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -79,6 +79,7 @@ private[spark] class TaskSetManager( private val numFailures = new Array[Int](numTasks) // key is taskId, value is a Map of executor id to when it failed private val failedExecutors = new HashMap[Int, HashMap[String, Long]]() + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksSuccessful = 0 @@ -179,26 +180,17 @@ private[spark] class TaskSetManager( } } - var hadAliveLocations = false for (loc <- tasks(index).preferredLocations) { for (execId <- loc.executorId) { addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - hadAliveLocations = true - } addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) for (rack <- sched.getRackForHost(loc.host)) { addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - if(sched.hasHostAliveOnRack(rack)){ - hadAliveLocations = true - } } } - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. 
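// Editorial sketch, not part of the patch: the Option-chaining used by
// executorHeartbeatReceived above to tag each (taskId, metrics) pair with its stage id,
// silently dropping tasks whose task set is no longer active. The maps below are toy
// stand-ins for taskIdToTaskSetId and activeTaskSets.
val taskIdToTaskSetId = Map(1L -> "ts0", 2L -> "ts1")
val taskSetIdToStageId = Map("ts0" -> 10)                     // "ts1" is no longer active

val heartbeat = Array((1L, "metricsA"), (2L, "metricsB"), (3L, "metricsC"))
val withStageIds = heartbeat.flatMap { case (taskId, metrics) =>
  taskIdToTaskSetId.get(taskId)
    .flatMap(taskSetIdToStageId.get)
    .map(stageId => (taskId, stageId, metrics))
}
// withStageIds == Array((1L, 10, "metricsA")): 2L's set is inactive, 3L is unknown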
+ if (tasks(index).preferredLocations == Nil) { addTo(pendingTasksWithNoPrefs) } @@ -239,7 +231,6 @@ private[spark] class TaskSetManager( */ private def findTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = { var indexOffset = list.size - while (indexOffset > 0) { indexOffset -= 1 val index = list(indexOffset) @@ -288,12 +279,12 @@ private[spark] class TaskSetManager( !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index) if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local + // Check for process-local tasks; note that tasks can be process-local // on multiple nodes when we replicate cached blocks, as in Spark Streaming for (index <- speculatableTasks if canRunOnHost(index)) { val prefs = tasks(index).preferredLocations val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { + if (executors.contains(execId)) { speculatableTasks -= index return Some((index, TaskLocality.PROCESS_LOCAL)) } @@ -310,6 +301,17 @@ private[spark] class TaskSetManager( } } + // Check for no-preference tasks + if (TaskLocality.isAllowed(locality, TaskLocality.NO_PREF)) { + for (index <- speculatableTasks if canRunOnHost(index)) { + val locations = tasks(index).preferredLocations + if (locations.size == 0) { + speculatableTasks -= index + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + } + } + // Check for rack-local tasks if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { for (rack <- sched.getRackForHost(host)) { @@ -341,20 +343,27 @@ private[spark] class TaskSetManager( * * @return An option containing (task index within the task set, locality, is speculative?) */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) + private def findTask(execId: String, host: String, maxLocality: TaskLocality.Value) : Option[(Int, TaskLocality.Value, Boolean)] = { for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) { return Some((index, TaskLocality.PROCESS_LOCAL, false)) } - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) { return Some((index, TaskLocality.NODE_LOCAL, false)) } } - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) { + // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic + for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) { + return Some((index, TaskLocality.PROCESS_LOCAL, false)) + } + } + + if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) { for { rack <- sched.getRackForHost(host) index <- findTaskFromList(execId, getPendingTasksForRack(rack)) @@ -363,25 +372,27 @@ private[spark] class TaskSetManager( } } - // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
- for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL, false)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { for (index <- findTaskFromList(execId, allPendingTasks)) { return Some((index, TaskLocality.ANY, false)) } } - // Finally, if all else has failed, find a speculative task - findSpeculativeTask(execId, host, locality).map { case (taskIndex, allowedLocality) => - (taskIndex, allowedLocality, true) - } + // find a speculative task if all others tasks have been scheduled + findSpeculativeTask(execId, host, maxLocality).map { + case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)} } /** * Respond to an offer of a single executor from the scheduler by finding a task + * + * NOTE: this function is either called with a maxLocality which + * would be adjusted by delay scheduling algorithm or it will be with a special + * NO_PREF locality which will be not modified + * + * @param execId the executor Id of the offered resource + * @param host the host Id of the offered resource + * @param maxLocality the maximum locality we want to schedule the tasks at */ def resourceOffer( execId: String, @@ -392,9 +403,14 @@ private[spark] class TaskSetManager( if (!isZombie) { val curTime = clock.getTime() - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks + var allowedLocality = maxLocality + + if (maxLocality != TaskLocality.NO_PREF) { + allowedLocality = getAllowedLocalityLevel(curTime) + if (allowedLocality > maxLocality) { + // We're not allowed to search for farther-away tasks + allowedLocality = maxLocality + } } findTask(execId, host, allowedLocality) match { @@ -410,8 +426,11 @@ private[spark] class TaskSetManager( taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime + // NO_PREF will not affect the variables related to delay scheduling + if (maxLocality != TaskLocality.NO_PREF) { + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + } // Serialize and return the task val startTime = clock.getTime() // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here @@ -639,8 +658,7 @@ private[spark] class TaskSetManager( override def executorLost(execId: String, host: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note + // Re-enqueue pending tasks for this host based on the status of the cluster. Note // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because findTaskFromList will skip already-running tasks. 
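// Editorial sketch, not part of the patch: the order findTask above walks the pending
// queues once NO_PREF exists, and how maxLocality caps the search. In the real code a
// noPref hit is reported as PROCESS_LOCAL; this sketch simply returns the queue it came
// from. Assumes the TaskLocality enumeration as extended by this change.
import org.apache.spark.scheduler.TaskLocality
import org.apache.spark.scheduler.TaskLocality._

def firstAllowed(
    maxLocality: TaskLocality,
    pending: Map[TaskLocality, Seq[Int]]): Option[(Int, TaskLocality)] = {
  val order = Seq(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY)
  for (level <- order if TaskLocality.isAllowed(maxLocality, level)) {
    pending.getOrElse(level, Nil).headOption.foreach { index =>
      return Some((index, level))
    }
  }
  None
}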
for (index <- getPendingTasksForExecutor(execId)) { @@ -671,6 +689,9 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } + // recalculate valid locality levels and waits when executor is lost + myLocalityLevels = computeValidLocalityLevels() + localityWaits = myLocalityLevels.map(getLocalityWait) } /** @@ -722,17 +743,17 @@ private[spark] class TaskSetManager( conf.get("spark.locality.wait.node", defaultWait).toLong case TaskLocality.RACK_LOCAL => conf.get("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L + case _ => 0L } } /** * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been * added to queues using addPendingTask. + * */ private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY} val levels = new ArrayBuffer[TaskLocality.TaskLocality] if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 && pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) { @@ -742,6 +763,9 @@ private[spark] class TaskSetManager( pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } + if (!pendingTasksWithNoPrefs.isEmpty) { + levels += NO_PREF + } if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 && pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) { levels += RACK_LOCAL @@ -751,20 +775,7 @@ private[spark] class TaskSetManager( levels.toArray } - // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available def executorAdded() { - def newLocAvail(index: Int): Boolean = { - for (loc <- tasks(index).preferredLocations) { - if (sched.hasExecutorsAliveOnHost(loc.host) || - (sched.getRackForHost(loc.host).isDefined && - sched.hasHostAliveOnRack(sched.getRackForHost(loc.host).get))) { - return true - } - } - false - } - logInfo("Re-computing pending task lists.") - pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_)) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bf2dc88e29048..a28446f6c8a6b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} @@ -46,6 +46,7 @@ private[spark] class SparkDeploySchedulerBackend( CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") + .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = 
sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp => cp.split(java.io.File.pathSeparator) } @@ -54,12 +55,13 @@ private[spark] class SparkDeploySchedulerBackend( cp.split(java.io.File.pathSeparator) } - val command = Command( - "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, - classPathEntries, libraryPathEntries, extraJavaOpts) - val sparkHome = sc.getSparkHome() + // Start executors with a few necessary configs for registering with the scheduler + val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", + args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - sparkHome, sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) + sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 5b897597fa285..3d1cf312ccc97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,8 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.storage.BlockManagerId private case class ReviveOffers() @@ -32,6 +33,8 @@ private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class StopExecutor() + /** * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend @@ -63,6 +66,9 @@ private[spark] class LocalActor( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + + case StopExecutor => + executor.stop() } def reviveOffers() { @@ -91,6 +97,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def stop() { + localActor ! 
StopExecutor } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index a7fa057ee05f7..34bc3124097bb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In /** * Calling reset to avoid memory leak: * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api - * But only call it every 10,000th time to avoid bloated serialization streams (when + * But only call it every 100th time to avoid bloated serialization streams (when * the stream 'resets' object class descriptions have to be re-written) */ def writeObject[T: ClassTag](t: T): SerializationStream = { objOut.writeObject(t) + counter += 1 if (counterReset > 0 && counter >= counterReset) { objOut.reset() counter = 0 - } else { - counter += 1 } this } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index fa79b25759153..407cb9db6ee9a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -47,12 +47,15 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 + private val bufferSize = + (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + + private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val registrator = conf.getOption("spark.kryo.registrator") - def newKryoOutput() = new KryoOutput(bufferSize) + def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala new file mode 100644 index 0000000000000..ee91a368b76ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkException, SparkConf} + +/** + * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory + * from this pool and release it as it spills data out. When a task ends, all its memory will be + * released by the Executor. + * + * This class tries to ensure that each thread gets a reasonable share of memory, instead of some + * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * this set changes. This is all done by synchronizing access on "this" to mutate state and using + * wait() and notifyAll() to signal changes. + */ +private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { + private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + + def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + + /** + * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * obtained, or 0 if none can be allocated. This call may block until there is enough free memory + * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active threads) before it is forced to spill. This can + * happen if the number of threads increases but an older thread had a lot of memory already. + */ + def tryToAcquire(numBytes: Long): Long = synchronized { + val threadId = Thread.currentThread().getId + assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) + + // Add this thread to the threadMemory map just so we can keep an accurate count of the number + // of active threads, to let other threads ramp down their memory in calls to tryToAcquire + if (!threadMemory.contains(threadId)) { + threadMemory(threadId) = 0L + notifyAll() // Will later cause waiting threads to wake up and check numThreads again + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // thread would have more than 1 / numActiveThreads of the memory) or we have enough free + // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). 
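To make the 1 / 2N and 1 / N bounds from the class comment above concrete, here is a standalone arithmetic sketch (illustrative names and values, not Spark code): with a 1000-byte pool and 4 active threads, a thread may never grow past 1 / N = 250 bytes, and it is guaranteed to reach 1 / 2N = 125 bytes before it can be forced to spill.

object ShuffleMemoryMathSketch {
  // Cap on a single grant: never let the requesting thread exceed 1 / numActiveThreads of the pool.
  def maxToGrant(maxMemory: Long, numActiveThreads: Int, curMem: Long, requested: Long): Long =
    math.min(requested, maxMemory / numActiveThreads - curMem)

  def main(args: Array[String]): Unit = {
    // 1000-byte pool, 4 threads, the caller already holds 100 bytes and asks for 500 more:
    // at most 250 - 100 = 150 bytes can be granted on this call.
    println(maxToGrant(maxMemory = 1000L, numActiveThreads = 4, curMem = 100L, requested = 500L))
  }
}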
+ while (true) { + val numActiveThreads = threadMemory.keys.size + val curMem = threadMemory(threadId) + val freeMemory = maxMemory - threadMemory.values.sum + + // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads + val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem) + + if (curMem < maxMemory / (2 * numActiveThreads)) { + // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; + // if we can't give it this much now, wait for other threads to free up memory + // (this happens if older threads allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } else { + logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + wait() + } + } else { + // Only give it as much memory as is free, which might be none if it reached 1 / numThreads + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** Release numBytes bytes for the current thread. */ + def release(numBytes: Long): Unit = synchronized { + val threadId = Thread.currentThread().getId + val curMem = threadMemory.getOrElse(threadId, 0L) + if (curMem < numBytes) { + throw new SparkException( + s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + } + threadMemory(threadId) -= numBytes + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } + + /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisThread(): Unit = synchronized { + val threadId = Thread.currentThread().getId + threadMemory.remove(threadId) + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } +} + +private object ShuffleMemoryManager { + /** + * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction + * of the memory pool and a safety factor since collections can sometimes grow bigger than + * the size we target before we estimate their sizes again. + */ + def getMaxMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 5b0940ecce29d..df98d18fa8193 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,7 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). */ -class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. 
*/ override def registerShuffle[K, V, C]( shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 76cdb8f4f8e8a..7c9dc8e5f88ef 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -18,11 +18,11 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{InterruptibleIterator, TaskContext} -import org.apache.spark.rdd.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.util.collection.ExternalSorter -class HashShuffleReader[K, C]( +private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, @@ -36,8 +36,8 @@ class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, - Serializer.getSerializer(dep.serializer)) + val ser = Serializer.getSerializer(dep.serializer) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { @@ -48,19 +48,23 @@ class HashShuffleReader[K, C]( } else if (dep.aggregator.isEmpty && dep.mapSideCombine) { throw new IllegalStateException("Aggregator is empty for map-side combine") } else { - iter + // Convert the Product2s to pairs since this is what downstream RDDs currently expect + iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) } - val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield { - val buf = aggregatedIter.toArray - if (sortOrder == SortOrder.ASCENDING) { - buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator - } else { - buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator - } + // Sort the output if there is a sort ordering defined. + dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, + // the ExternalSorter won't spill to disk. 
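For orientation, the keyOrdering branch here appears to be the path a reduce-side sort such as sortByKey takes after this change: the sort (and any spilling) happens in an ExternalSorter on the read side, and the spill sizes show up in the task metrics. A minimal usage sketch, assuming an already-constructed SparkContext named sc:

// Assumes `sc` is an existing SparkContext.
val pairs = sc.parallelize(Seq(("b", 2), ("c", 3), ("a", 1)))
val sorted = pairs.sortByKey().collect()  // Array((a,1), (b,2), (c,3)), sorted on the reduce side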
+ val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + sorter.write(aggregatedIter) + context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + sorter.iterator + case None => + aggregatedIter } - - sortedIter.getOrElse(aggregatedIter) } /** Close this reader */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 9b78228519da4..45d3b8b9b8725 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -class HashShuffleWriter[K, V]( +private[spark] class HashShuffleWriter[K, V]( handle: BaseShuffleHandle[K, V, _], mapId: Int, context: TaskContext) @@ -33,6 +33,10 @@ class HashShuffleWriter[K, V]( private val dep = handle.dependency private val numOutputSplits = dep.partitioner.numPartitions private val metrics = context.taskMetrics + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. private var stopping = false private val blockManager = SparkEnv.get.blockManager @@ -61,7 +65,8 @@ class HashShuffleWriter[K, V]( } /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { + override def stop(initiallySuccess: Boolean): Option[MapStatus] = { + var success = initiallySuccess try { if (stopping) { return None @@ -69,15 +74,16 @@ class HashShuffleWriter[K, V]( stopping = true if (success) { try { - return Some(commitWritesAndBuildStatus()) + Some(commitWritesAndBuildStatus()) } catch { case e: Exception => + success = false revertWrites() throw e } } else { revertWrites() - return None + None } } finally { // Release the writers back to the shuffle block manager. @@ -96,8 +102,7 @@ class HashShuffleWriter[K, V]( var totalBytes = 0L var totalTime = 0L val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() + writer.commitAndClose() val size = writer.fileSegment().length totalBytes += size totalTime += writer.timeWriting() @@ -116,8 +121,7 @@ class HashShuffleWriter[K, V]( private def revertWrites(): Unit = { if (shuffle != null && shuffle.writers != null) { for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + writer.revertPartialWritesAndClose() } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala new file mode 100644 index 0000000000000..6dcca47ea7c0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{DataInputStream, FileInputStream} + +import org.apache.spark.shuffle._ +import org.apache.spark.{TaskContext, ShuffleDependency} +import org.apache.spark.shuffle.hash.HashShuffleReader +import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId} + +private[spark] class SortShuffleManager extends ShuffleManager { + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + // We currently use the same block store shuffle fetcher as the hash-based shuffle. + new HashShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) + : ShuffleWriter[K, V] = { + new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Unit = {} + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = {} + + /** Get the location of a block in a map output file. Uses the index file we create for it. */ + def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = { + // The block is actually going to be a range of a single map output file for this map, so + // figure out the ID of the consolidated file, then the offset within that from our index + val consolidatedId = blockId.copy(reduceId = 0) + val indexFile = diskManager.getFile(consolidatedId.name + ".index") + val in = new DataInputStream(new FileInputStream(indexFile)) + try { + in.skip(blockId.reduceId * 8) + val offset = in.readLong() + val nextOffset = in.readLong() + new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset) + } finally { + in.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala new file mode 100644 index 0000000000000..24db2f287a47b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream} + +import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class SortShuffleWriter[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency + private val numPartitions = dep.partitioner.numPartitions + + private val blockManager = SparkEnv.get.blockManager + private val ser = Serializer.getSerializer(dep.serializer.orNull) + + private val conf = SparkEnv.get.conf + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + private var sorter: ExternalSorter[K, V, _] = null + private var outputFile: File = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private var stopping = false + + private var mapStatus: MapStatus = null + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + // Get an iterator with the elements for each partition ID + val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = { + if (dep.mapSideCombine) { + if (!dep.aggregator.isDefined) { + throw new IllegalStateException("Aggregator is empty for map-side combine") + } + sorter = new ExternalSorter[K, V, C]( + dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + sorter.write(records) + sorter.partitionedIterator + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we + // don't care whether the keys get sorted in each partition; that will be done on the + // reduce side if the operation being run is sortByKey. + sorter = new ExternalSorter[K, V, V]( + None, Some(dep.partitioner), None, dep.serializer) + sorter.write(records) + sorter.partitionedIterator + } + } + + // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later + // serve different ranges of this file using an index file that we create at the end. 
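The index file mentioned in the comment above is simply numPartitions + 1 offsets written as longs; getBlockLocation in SortShuffleManager (earlier in this diff) recovers a partition's segment by skipping reduceId * 8 bytes and reading two consecutive offsets. A self-contained round-trip sketch, using an illustrative file name and made-up offsets rather than real shuffle output:

import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, File,
  FileInputStream, FileOutputStream}

object IndexFileSketch {
  def main(args: Array[String]): Unit = {
    // Offsets for numPartitions = 4; partition 2 is empty, so its offset repeats.
    val offsets = Array(0L, 100L, 250L, 250L, 400L)
    val indexFile = File.createTempFile("shuffle_0_0_0", ".index")

    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
    try offsets.foreach(o => out.writeLong(o)) finally out.close()

    // Segment for reduce partition 1: skip 1 * 8 bytes, then read two consecutive offsets.
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      in.skip(1 * 8)
      val offset = in.readLong()
      val nextOffset = in.readLong()
      println(s"partition 1 -> offset=$offset, length=${nextOffset - offset}")  // offset=100, length=150
    } finally {
      in.close()
      indexFile.delete()
    }
  }
}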
+ val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) + outputFile = blockManager.diskBlockManager.getFile(blockId) + + // Track location of each range in the output file + val offsets = new Array[Long](numPartitions + 1) + val lengths = new Array[Long](numPartitions) + + // Statistics + var totalBytes = 0L + var totalTime = 0L + + for ((id, elements) <- partitions) { + if (elements.hasNext) { + val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize) + for (elem <- elements) { + writer.write(elem) + } + writer.commitAndClose() + val segment = writer.fileSegment() + offsets(id + 1) = segment.offset + segment.length + lengths(id) = segment.length + totalTime += writer.timeWriting() + totalBytes += segment.length + } else { + // The partition is empty; don't create a new writer to avoid writing headers, etc + offsets(id + 1) = offsets(id) + } + } + + val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics) + context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + + // Write an index file with the offsets of each block, plus a final offset at the end for the + // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure + // out where each block begins and ends. + + val diskBlockManager = blockManager.diskBlockManager + val indexFile = diskBlockManager.getFile(blockId.name + ".index") + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + var i = 0 + while (i < numPartitions + 1) { + out.writeLong(offsets(i)) + i += 1 + } + } finally { + out.close() + } + + // Register our map output with the ShuffleBlockManager, which handles cleaning it over time + blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions) + + mapStatus = new MapStatus(blockManager.blockManagerId, + lengths.map(MapOutputTracker.compressSize)) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + return Option(mapStatus) + } else { + // The map task failed, so delete our output file if we created one + if (outputFile != null) { + outputFile.delete() + } + return None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + sorter.stop() + sorter = null + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 69905a960a2ca..ccf830e118ee7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -200,14 +200,17 @@ object BlockFetcherIterator { // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlocksToFetch) { - getLocalFromDisk(id, serializer) match { - case Some(iter) => { - // Pass 0 as size since it's not in flight - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") + try { + // 
getLocalFromDisk never return None but throws BlockException + val iter = getLocalFromDisk(id, serializer).get + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) + logDebug("Got local block " + id) + } catch { + case e: Exception => { + logError(s"Error occurred while fetching local blocks", e) + results.put(new FetchResult(id, -1, null)) + return } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 42ec181b00bb3..c1756ac905417 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -54,11 +54,15 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { } @DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) - extends BlockId { +case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +@DeveloperApi +case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { + def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" +} + @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) @@ -88,6 +92,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -99,6 +104,8 @@ object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => + ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d746526639e58..3876cf43e2a7d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -60,10 +60,12 @@ private[spark] class BlockManager( mapOutputTracker: MapOutputTracker) extends Logging { + private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) - val connectionManager = new ConnectionManager(0, conf, securityManager) + val connectionManager = + new ConnectionManager(port, conf, securityManager, "Connection manager for block manager") implicit val futureExecContext = connectionManager.futureExecContext @@ -116,15 +118,6 @@ private[spark] class BlockManager( private var asyncReregisterTask: Future[Unit] = null private val asyncReregisterLock = new Object - private def heartBeat(): Unit = { - if (!master.sendHeartBeat(blockManagerId)) { - reregister() - } - } - - private val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - private var 
heartBeatTask: Cancellable = null - private val metadataCleaner = new MetadataCleaner( MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) private val broadcastCleaner = new MetadataCleaner( @@ -161,11 +154,6 @@ private[spark] class BlockManager( private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { - Utils.tryOrExit { heartBeat() } - } - } } /** @@ -195,7 +183,7 @@ private[spark] class BlockManager( * * Note that this method must be called without any BlockInfo locks held. */ - private def reregister(): Unit = { + def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") master.registerBlockManager(blockManagerId, maxMemory, slaveActor) @@ -1065,9 +1053,6 @@ private[spark] class BlockManager( } def stop(): Unit = { - if (heartBeatTask != null) { - heartBeatTask.cancel() - } connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() @@ -1095,12 +1080,6 @@ private[spark] object BlockManager extends Logging { (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } - def getHeartBeatFrequency(conf: SparkConf): Long = - conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4 - - def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = - conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false) - /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7897fade2df2b..669307765d1fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -21,7 +21,6 @@ import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global import akka.actor._ -import akka.pattern.ask import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ @@ -29,8 +28,8 @@ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3) - val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000) + private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) + private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" @@ -42,15 +41,6 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log logInfo("Removed " + execId + " successfully in removeExecutor") } - /** - * Send the driver actor a heart beat from the slave. Returns true if everything works out, - * false if the driver does not know about the given block manager, which means the block - * manager should re-register. - */ - def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askDriverWithReply[Boolean](HeartBeat(blockManagerId)) - } - /** Register the BlockManager's id with the driver. 
*/ def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { logInfo("Trying to register BlockManager") @@ -223,33 +213,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log * throw a SparkException if this fails. */ private def askDriverWithReply[T](message: Any): T = { - // TODO: Consider removing multiple attempts - if (driverActor == null) { - throw new SparkException("Error sending message to BlockManager as driverActor is null " + - "[message = " + message + "]") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPTS) { - attempts += 1 - try { - val future = driverActor.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("BlockManagerMaster returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) - } - Thread.sleep(AKKA_RETRY_INTERVAL_MS) - } - - throw new SparkException( - "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) + AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, + timeout) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index de1cc5539fb48..bd31e3c5a187f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,25 +52,24 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = conf.get("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong + val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", + math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.get("spark.storage.blockManagerTimeoutIntervalMs", - "60000").toLong + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", + 60000) var timeoutCheckingTask: Cancellable = null override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } + import context.dispatcher + timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, + checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) super.preStart() } def receive = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -129,8 +128,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case ExpireDeadHosts => expireDeadHosts() - case HeartBeat(blockManagerId) => - sender ! heartBeat(blockManagerId) + case BlockManagerHeartbeat(blockManagerId) => + sender ! 
heartbeatReceived(blockManagerId) case other => logWarning("Got unknown message: " + other) @@ -216,7 +215,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val minSeenTime = now - slaveTimeout val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { + if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") toRemove += info.blockManagerId @@ -230,7 +229,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } - private def heartBeat(blockManagerId: BlockManagerId): Boolean = { + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { blockManagerId.executorId == "" && !isLocal } else { @@ -264,9 +267,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } private def storageStatus: Array[StorageStatus] = { - blockManagerInfo.map { case(blockManagerId, info) => - val blockMap = mutable.Map[BlockId, BlockStatus](info.blocks.toSeq: _*) - new StorageStatus(blockManagerId, info.maxMem, blockMap) + blockManagerInfo.map { case (blockManagerId, info) => + new StorageStatus(blockManagerId, info.maxMem, info.blocks) }.toArray } @@ -421,7 +423,14 @@ case class BlockStatus( storageLevel: StorageLevel, memSize: Long, diskSize: Long, - tachyonSize: Long) + tachyonSize: Long) { + def isCached: Boolean = memSize + diskSize + tachyonSize > 0 +} + +@DeveloperApi +object BlockStatus { + def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) +} private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 2b53bf33b5fba..10b65286fb7db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef -private[storage] object BlockManagerMessages { +private[spark] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. 
////////////////////////////////////////////////////////////////////////////////// @@ -53,8 +53,6 @@ private[storage] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, @@ -124,5 +122,7 @@ private[storage] object BlockManagerMessages { case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) extends ToBlockManagerMaster + case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 687586490abfe..3f14c40ec61cb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -30,7 +30,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) + val maxMem = storageStatusList.map(_.maxMem).sum maxMem / 1024 / 1024 } }) @@ -38,7 +38,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) + val remainingMem = storageStatusList.map(_.memRemaining).sum remainingMem / 1024 / 1024 } }) @@ -46,20 +46,15 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) - (maxMem - remainingMem) / 1024 / 1024 + val memUsed = storageStatusList.map(_.memUsed).sum + memUsed / 1024 / 1024 } }) metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList - .flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_ + _) - .getOrElse(0L) - + val diskSpaceUsed = storageStatusList.map(_.diskUsed).sum diskSpaceUsed / 1024 / 1024 } }) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a2687e6be4e34..01d46e1ffc960 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -39,16 +39,16 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { def isOpen: Boolean /** - * Flush the partial writes and commit them as a single atomic block. Return the - * number of bytes written for this commit. + * Flush the partial writes and commit them as a single atomic block. 
*/ - def commit(): Long + def commitAndClose(): Unit /** * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. */ - def revertPartialWrites() + def revertPartialWritesAndClose() /** * Writes an object. @@ -57,6 +57,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { /** * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. */ def fileSegment(): FileSegment @@ -108,7 +109,7 @@ private[spark] class DiskBlockObjectWriter( private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private val initialPosition = file.length() - private var lastValidPosition = initialPosition + private var finalPosition: Long = -1 private var initialized = false private var _timeWriting = 0L @@ -116,7 +117,6 @@ private[spark] class DiskBlockObjectWriter( fos = new FileOutputStream(file, true) ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() - lastValidPosition = initialPosition bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) initialized = true @@ -147,28 +147,36 @@ private[spark] class DiskBlockObjectWriter( override def isOpen: Boolean = objOut != null - override def commit(): Long = { + override def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. objOut.flush() bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition + close() } + finalPosition = file.length() } - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. 
+ override def revertPartialWritesAndClose() { + try { + if (initialized) { + objOut.flush() + bs.flush() + close() + } + + val truncateStream = new FileOutputStream(file, true) + try { + truncateStream.getChannel.truncate(initialPosition) + } finally { + truncateStream.close() + } + } catch { + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) } } @@ -188,6 +196,7 @@ private[spark] class DiskBlockObjectWriter( // Only valid if called after commit() override def bytesWritten: Long = { - lastValidPosition - initialPosition + assert(finalPosition != -1, "bytesWritten is only valid after successful commit()") + finalPosition - initialPosition } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2e7ed7538e6e5..4d66ccea211fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -21,10 +21,11 @@ import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} -import org.apache.spark.Logging +import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.network.netty.{PathResolver, ShuffleSender} import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.SortShuffleManager /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -34,11 +35,13 @@ import org.apache.spark.util.Utils * * @param rootDirs The directories to use for storing block files. Data will be hashed among these. */ -private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String) +private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String) extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64) + + private val subDirsPerLocalDir = + shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid @@ -54,13 +57,19 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD addShutdownHook() /** - * Returns the physical file segment in which the given BlockId is located. - * If the BlockId has been mapped to a specific FileSegment, that will be returned. - * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly. + * Returns the physical file segment in which the given BlockId is located. If the BlockId has + * been mapped to a specific FileSegment by the shuffle layer, that will be returned. + * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId. 
*/ def getBlockLocation(blockId: BlockId): FileSegment = { - if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) { - shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) + val env = SparkEnv.get // NOTE: can be null in unit tests + if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) { + // For sort-based shuffle, let it figure out its blocks + val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager] + sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this) + } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) { + // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this + shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) } else { val file = getFile(blockId.name) new FileSegment(file, 0, file.length()) @@ -99,13 +108,18 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD getBlockLocation(blockId).file.exists() } - /** List all the blocks currently stored on disk by the disk manager. */ - def getAllBlocks(): Seq[BlockId] = { + /** List all the files currently stored on disk by the disk manager. */ + def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories subDirs.flatten.filter(_ != null).flatMap { dir => - val files = dir.list() + val files = dir.listFiles() if (files != null) files else Seq.empty - }.map(BlockId.apply) + } + } + + /** List all the blocks currently stored on disk by the disk manager. */ + def getAllBlocks(): Seq[BlockId] = { + getAllFiles().map(f => BlockId(f.getName)) } /** Produces a unique block id and File suitable for intermediate results. */ diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 5a72e216872a6..120c327a7e580 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -34,6 +34,8 @@ class RDDInfo( var diskSize = 0L var tachyonSize = 0L + def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0 + override def toString = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 35910e552fe86..f9fdffae8bd8f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.shuffle.sort.SortShuffleManager /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -58,6 +59,7 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroups associated with the block's reducer. 
*/ +// TODO: Factor this into a separate class for each ShuffleManager implementation private[spark] class ShuffleBlockManager(blockManager: BlockManager) extends Logging { def conf = blockManager.conf @@ -67,7 +69,11 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + // Are we using sort-based shuffle? + val sortBasedShuffle = + conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName + + private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused @@ -91,6 +97,20 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) + /** + * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle + * because it just writes a single file by itself. + */ + def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = { + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + val shuffleState = shuffleStates(shuffleId) + shuffleState.completedMapTasks.add(mapId) + } + + /** + * Get a ShuffleWriterGroup for the given map task, which will register it as complete + * when the writers are closed successfully + */ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) @@ -124,7 +144,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { if (consolidateShuffleFiles) { if (success) { val offsets = writers.map(_.fileSegment().offset) - fileGroup.recordMapOutput(mapId, offsets) + val lengths = writers.map(_.fileSegment().length) + fileGroup.recordMapOutput(mapId, offsets, lengths) } recycleFileGroup(fileGroup) } else { @@ -182,7 +203,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { + if (sortBasedShuffle) { + // There's a single block ID for each map, plus an index file for it + for (mapId <- state.completedMapTasks) { + val blockId = new ShuffleBlockId(shuffleId, mapId, 0) + blockManager.diskBlockManager.getFile(blockId).delete() + blockManager.diskBlockManager.getFile(blockId.name + ".index").delete() + } + } else if (consolidateShuffleFiles) { for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { file.delete() } @@ -220,6 +248,8 @@ object ShuffleBlockManager { * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. */ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { + private var numBlocks: Int = 0 + /** * Stores the absolute index of each mapId in the files of this group. For instance, * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. @@ -227,23 +257,27 @@ object ShuffleBlockManager { private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() /** - * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file. - * This ordering allows us to compute block lengths by examining the following block offset. 
+ * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by + * position in the file. * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every * reducer. */ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { new PrimitiveVector[Long]() } - - def numBlocks = mapIdToIndex.size + private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } def apply(bucketId: Int) = files(bucketId) - def recordMapOutput(mapId: Int, offsets: Array[Long]) { + def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { + assert(offsets.length == lengths.length) mapIdToIndex(mapId) = numBlocks + numBlocks += 1 for (i <- 0 until offsets.length) { blockOffsetsByReducer(i) += offsets(i) + blockLengthsByReducer(i) += lengths(i) } } @@ -251,16 +285,11 @@ object ShuffleBlockManager { def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { val file = files(reducerId) val blockOffsets = blockOffsetsByReducer(reducerId) + val blockLengths = blockLengthsByReducer(reducerId) val index = mapIdToIndex.getOrElse(mapId, -1) if (index >= 0) { val offset = blockOffsets(index) - val length = - if (index + 1 < numBlocks) { - blockOffsets(index + 1) - offset - } else { - file.length() - offset - } - assert(length >= 0) + val length = blockLengths(index) Some(new FileSegment(file, offset, length)) } else { None diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 41c960c867e2e..d9066f766476e 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -35,13 +35,12 @@ class StorageStatusListener extends SparkListener { /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { - val filteredStatus = executorIdToStorageStatus.get(execId) - filteredStatus.foreach { storageStatus => + executorIdToStorageStatus.get(execId).foreach { storageStatus => updatedBlocks.foreach { case (blockId, updatedStatus) => if (updatedStatus.storageLevel == StorageLevel.NONE) { - storageStatus.blocks.remove(blockId) + storageStatus.removeBlock(blockId) } else { - storageStatus.blocks(blockId) = updatedStatus + storageStatus.updateBlock(blockId, updatedStatus) } } } @@ -50,9 +49,8 @@ class StorageStatusListener extends SparkListener { /** Update storage status list to reflect the removal of an RDD from the cache */ private def updateStorageStatus(unpersistedRDDId: Int) { storageStatusList.foreach { storageStatus => - val unpersistedBlocksIds = storageStatus.rddBlocks.keys.filter(_.rddId == unpersistedRDDId) - unpersistedBlocksIds.foreach { blockId => - storageStatus.blocks.remove(blockId) + storageStatus.rddBlocksById(unpersistedRDDId).foreach { case (blockId, _) => + storageStatus.removeBlock(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 177281f663367..2bd6b749be261 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -20,122 +20,255 @@ package org.apache.spark.storage import scala.collection.Map import scala.collection.mutable -import 
org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: * Storage information for each BlockManager. + * + * This class assumes BlockId and BlockStatus are immutable, such that the consumers of this + * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi -class StorageStatus( - val blockManagerId: BlockManagerId, - val maxMem: Long, - val blocks: mutable.Map[BlockId, BlockStatus] = mutable.Map.empty) { +class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { - def memUsed = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L) + /** + * Internal representation of the blocks stored in this block manager. + * + * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks. + * These collections should only be mutated through the add/update/removeBlock methods. + */ + private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] + private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] - def memUsedByRDD(rddId: Int) = - rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_ + _).getOrElse(0L) + /** + * Storage information of the blocks that entails memory, disk, and off-heap memory usage. + * + * As with the block maps, we store the storage information separately for RDD blocks and + * non-RDD blocks for the same reason. In particular, RDD storage information is stored + * in a map indexed by the RDD ID to the following 4-tuple: + * + * (memory size, disk size, off-heap size, storage level) + * + * We assume that all the blocks that belong to the same RDD have the same storage level. + * This field is not relevant to non-RDD blocks, however, so the storage information for + * non-RDD blocks contains only the first 3 fields (in the same order). + */ + private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, Long, StorageLevel)] + private var _nonRddStorageInfo: (Long, Long, Long) = (0L, 0L, 0L) - def diskUsed = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) + /** Create a storage status with an initial set of blocks, leaving the source unmodified. */ + def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { + this(bmid, maxMem) + initialBlocks.foreach { case (bid, bstatus) => addBlock(bid, bstatus) } + } - def diskUsedByRDD(rddId: Int) = - rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) + /** + * Return the blocks stored in this block manager. + * + * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * concatenating them together. Much faster alternatives exist for common operations such as + * contains, get, and size. + */ + def blocks: Map[BlockId, BlockStatus] = _nonRddBlocks ++ rddBlocks - def memRemaining: Long = maxMem - memUsed + /** + * Return the RDD blocks stored in this block manager. + * + * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * concatenating them together. Much faster alternatives exist for common operations such as + * getting the memory, disk, and off-heap memory sizes occupied by this RDD. + */ + def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } - def rddBlocks = blocks.collect { case (rdd: RDDBlockId, status) => (rdd, status) } -} + /** Return the blocks that belong to the given RDD stored in this block manager. 
*/ + def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = { + _rddBlocks.get(rddId).getOrElse(Map.empty) + } -/** Helper methods for storage-related objects. */ -private[spark] object StorageUtils { + /** Add the given block to this storage status. If it already exists, overwrite it. */ + private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { + updateStorageInfo(blockId, blockStatus) + blockId match { + case RDDBlockId(rddId, _) => + _rddBlocks.getOrElseUpdate(rddId, new mutable.HashMap)(blockId) = blockStatus + case _ => + _nonRddBlocks(blockId) = blockStatus + } + } + + /** Update the given block in this storage status. If it doesn't already exist, add it. */ + private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { + addBlock(blockId, blockStatus) + } + + /** Remove the given block from this storage status. */ + private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = { + updateStorageInfo(blockId, BlockStatus.empty) + blockId match { + case RDDBlockId(rddId, _) => + // Actually remove the block, if it exists + if (_rddBlocks.contains(rddId)) { + val removed = _rddBlocks(rddId).remove(blockId) + // If the given RDD has no more blocks left, remove the RDD + if (_rddBlocks(rddId).isEmpty) { + _rddBlocks.remove(rddId) + } + removed + } else { + None + } + case _ => + _nonRddBlocks.remove(blockId) + } + } /** - * Returns basic information of all RDDs persisted in the given SparkContext. This does not - * include storage information. + * Return whether the given block is stored in this block manager in O(1) time. + * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. */ - def rddInfoFromSparkContext(sc: SparkContext): Array[RDDInfo] = { - sc.persistentRdds.values.map { rdd => - val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - val rddNumPartitions = rdd.partitions.size - val rddStorageLevel = rdd.getStorageLevel - val rddInfo = new RDDInfo(rdd.id, rddName, rddNumPartitions, rddStorageLevel) - rddInfo - }.toArray + def containsBlock(blockId: BlockId): Boolean = { + blockId match { + case RDDBlockId(rddId, _) => + _rddBlocks.get(rddId).exists(_.contains(blockId)) + case _ => + _nonRddBlocks.contains(blockId) + } } - /** Returns storage information of all RDDs persisted in the given SparkContext. */ - def rddInfoFromStorageStatus( - storageStatuses: Seq[StorageStatus], - sc: SparkContext): Array[RDDInfo] = { - rddInfoFromStorageStatus(storageStatuses, rddInfoFromSparkContext(sc)) + /** + * Return the given block stored in this block manager in O(1) time. + * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + */ + def getBlock(blockId: BlockId): Option[BlockStatus] = { + blockId match { + case RDDBlockId(rddId, _) => + _rddBlocks.get(rddId).map(_.get(blockId)).flatten + case _ => + _nonRddBlocks.get(blockId) + } } - /** Returns storage information of all RDDs in the given list. 
*/ - def rddInfoFromStorageStatus( - storageStatuses: Seq[StorageStatus], - rddInfos: Seq[RDDInfo], - updatedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty): Array[RDDInfo] = { - - // Mapping from a block ID -> its status - val blockMap = mutable.Map(storageStatuses.flatMap(_.rddBlocks): _*) - - // Record updated blocks, if any - updatedBlocks - .collect { case (id: RDDBlockId, status) => (id, status) } - .foreach { case (id, status) => blockMap(id) = status } - - // Mapping from RDD ID -> an array of associated BlockStatuses - val rddBlockMap = blockMap - .groupBy { case (k, _) => k.rddId } - .mapValues(_.values.toArray) - - // Mapping from RDD ID -> the associated RDDInfo (with potentially outdated storage information) - val rddInfoMap = rddInfos.map { info => (info.id, info) }.toMap - - val rddStorageInfos = rddBlockMap.flatMap { case (rddId, blocks) => - // Add up memory, disk and Tachyon sizes - val persistedBlocks = - blocks.filter { status => status.memSize + status.diskSize + status.tachyonSize > 0 } - val _storageLevel = - if (persistedBlocks.length > 0) persistedBlocks(0).storageLevel else StorageLevel.NONE - val memSize = persistedBlocks.map(_.memSize).reduceOption(_ + _).getOrElse(0L) - val diskSize = persistedBlocks.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) - val tachyonSize = persistedBlocks.map(_.tachyonSize).reduceOption(_ + _).getOrElse(0L) - rddInfoMap.get(rddId).map { rddInfo => - rddInfo.storageLevel = _storageLevel - rddInfo.numCachedPartitions = persistedBlocks.length - rddInfo.memSize = memSize - rddInfo.diskSize = diskSize - rddInfo.tachyonSize = tachyonSize - rddInfo - } - }.toArray + /** + * Return the number of blocks stored in this block manager in O(RDDs) time. + * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + */ + def numBlocks: Int = _nonRddBlocks.size + numRddBlocks + + /** + * Return the number of RDD blocks stored in this block manager in O(RDDs) time. + * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + */ + def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum - scala.util.Sorting.quickSort(rddStorageInfos) - rddStorageInfos + /** + * Return the number of blocks that belong to the given RDD in O(1) time. + * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * O(blocks in this RDD) time. + */ + def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) + + /** Return the memory remaining in this block manager. */ + def memRemaining: Long = maxMem - memUsed + + /** Return the memory used by this block manager. */ + def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + + /** Return the disk space used by this block manager. */ + def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + + /** Return the off-heap space used by this block manager. */ + def offHeapUsed: Long = _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum + + /** Return the memory used by the given RDD in this block manager in O(1) time. */ + def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) + + /** Return the disk space used by the given RDD in this block manager in O(1) time. */ + def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) + + /** Return the off-heap space used by the given RDD in this block manager in O(1) time. 
*/ + def offHeapUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._3).getOrElse(0L) + + /** Return the storage level, if any, used by the given RDD in this block manager. */ + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._4) + + /** + * Update the relevant storage info, taking into account any existing status for this block. + */ + private def updateStorageInfo(blockId: BlockId, newBlockStatus: BlockStatus): Unit = { + val oldBlockStatus = getBlock(blockId).getOrElse(BlockStatus.empty) + val changeInMem = newBlockStatus.memSize - oldBlockStatus.memSize + val changeInDisk = newBlockStatus.diskSize - oldBlockStatus.diskSize + val changeInTachyon = newBlockStatus.tachyonSize - oldBlockStatus.tachyonSize + val level = newBlockStatus.storageLevel + + // Compute new info from old info + val (oldMem, oldDisk, oldTachyon) = blockId match { + case RDDBlockId(rddId, _) => + _rddStorageInfo.get(rddId) + .map { case (mem, disk, tachyon, _) => (mem, disk, tachyon) } + .getOrElse((0L, 0L, 0L)) + case _ => + _nonRddStorageInfo + } + val newMem = math.max(oldMem + changeInMem, 0L) + val newDisk = math.max(oldDisk + changeInDisk, 0L) + val newTachyon = math.max(oldTachyon + changeInTachyon, 0L) + + // Set the correct info + blockId match { + case RDDBlockId(rddId, _) => + // If this RDD is no longer persisted, remove it + if (newMem + newDisk + newTachyon == 0) { + _rddStorageInfo.remove(rddId) + } else { + _rddStorageInfo(rddId) = (newMem, newDisk, newTachyon, level) + } + case _ => + _nonRddStorageInfo = (newMem, newDisk, newTachyon) + } } - /** Returns a mapping from BlockId to the locations of the associated block. */ - def blockLocationsFromStorageStatus( - storageStatuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { - val blockLocationPairs = storageStatuses.flatMap { storageStatus => - storageStatus.blocks.map { case (bid, _) => (bid, storageStatus.blockManagerId.hostPort) } +} + +/** Helper methods for storage-related objects. */ +private[spark] object StorageUtils { + + /** + * Update the given list of RDDInfo with the given list of storage statuses. + * This method overwrites the old values stored in the RDDInfo's. + */ + def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = { + rddInfos.foreach { rddInfo => + val rddId = rddInfo.id + // Assume all blocks belonging to the same RDD have the same storage level + val storageLevel = statuses + .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) + val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum + val memSize = statuses.map(_.memUsedByRdd(rddId)).sum + val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum + val tachyonSize = statuses.map(_.offHeapUsedByRdd(rddId)).sum + + rddInfo.storageLevel = storageLevel + rddInfo.numCachedPartitions = numCachedPartitions + rddInfo.memSize = memSize + rddInfo.diskSize = diskSize + rddInfo.tachyonSize = tachyonSize } - blockLocationPairs.toMap - .groupBy { case (blockId, _) => blockId } - .mapValues(_.values.toSeq) } - /** Filters the given list of StorageStatus by the given RDD ID. 
*/ - def filterStorageStatusByRDD( - storageStatuses: Seq[StorageStatus], - rddId: Int): Array[StorageStatus] = { - storageStatuses.map { status => - val filteredBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toSeq - val filteredBlockMap = mutable.Map[BlockId, BlockStatus](filteredBlocks: _*) - new StorageStatus(status.blockManagerId, status.maxMem, filteredBlockMap) - }.toArray + /** + * Return a mapping from block ID to its locations for each block that belongs to the given RDD. + */ + def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { + val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]] + statuses.foreach { status => + status.rddBlocksById(rddId).foreach { case (bid, _) => + val location = status.blockManagerId.hostPort + blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location + } + } + blockLocations } + } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index a2535e3c1c41f..29e9cf947856f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -174,40 +174,32 @@ private[spark] object JettyUtils extends Logging { hostName: String, port: Int, handlers: Seq[ServletContextHandler], - conf: SparkConf): ServerInfo = { + conf: SparkConf, + serverName: String = ""): ServerInfo = { val collection = new ContextHandlerCollection collection.setHandlers(handlers.toArray) addFilters(handlers, conf) - @tailrec + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) val pool = new QueuedThreadPool pool.setDaemon(true) server.setThreadPool(pool) server.setHandler(collection) - - Try { + try { server.start() - } match { - case s: Success[_] => - (server, server.getConnectors.head.getLocalPort) - case f: Failure[_] => - val nextPort = (currentPort + 1) % 65536 + (server, server.getConnectors.head.getLocalPort) + } catch { + case e: Exception => server.stop() pool.stop() - val msg = s"Failed to create UI on port $currentPort. Trying again on port $nextPort." 
- if (f.toString.contains("Address already in use")) { - logWarning(s"$msg - $f") - } else { - logError(msg, f.exception) - } - connect(nextPort) + throw e } } - val (server, boundPort) = connect(port) + val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName) ServerInfo(server, boundPort, collection) } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 097a1b81e1dd1..6c788a37dc70b 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -36,7 +36,7 @@ private[spark] class SparkUI( val listenerBus: SparkListenerBus, var appName: String, val basePath: String = "") - extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath) + extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging { def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 856273e1d4e21..5f52f95088007 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -39,7 +39,8 @@ private[spark] abstract class WebUI( securityManager: SecurityManager, port: Int, conf: SparkConf, - basePath: String = "") + basePath: String = "", + name: String = "") extends Logging { protected val tabs = ArrayBuffer[WebUITab]() @@ -97,7 +98,7 @@ private[spark] abstract class WebUI( def bind() { assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className)) try { - serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf)) + serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf, name)) logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort)) } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b358c855e1c88..b814b0e6b8509 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -49,9 +49,9 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { val storageStatusList = listener.storageStatusList - val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_ + _) - val memUsed = storageStatusList.map(_.memUsed).fold(0L)(_ + _) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_ + _) + val maxMem = storageStatusList.map(_.maxMem).sum + val memUsed = storageStatusList.map(_.memUsed).sum + val diskUsed = storageStatusList.map(_.diskUsed).sum val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execInfoSorted = execInfo.sortBy(_.id) @@ -80,7 +80,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { - {execInfoSorted.map(execRow(_))} + {execInfoSorted.map(execRow)} @@ -91,7 +91,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
  • Memory: {Utils.bytesToString(memUsed)} Used ({Utils.bytesToString(maxMem)} Total)
  • -
  • Disk: {Utils.bytesToString(diskSpaceUsed)} Used
  • +
  • Disk: {Utils.bytesToString(diskUsed)} Used
  • @@ -145,7 +145,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { val status = listener.storageStatusList(statusId) val execId = status.blockManagerId.executorId val hostPort = status.blockManagerId.hostPort - val rddBlocks = status.blocks.size + val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index efb527b4f03e6..a57a354620163 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, ListBuffer, Map} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -65,6 +65,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) + for ((id, info) <- stageCompleted.stageInfo.accumulables) { + stageData.accumulables(id) = info + } + poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) activeStages.remove(stageId) if (stage.failureReason.isEmpty) { @@ -130,32 +134,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) - // create executor summary map if necessary - val executorSummaryMap = stageData.executorSummary - executorSummaryMap.getOrElseUpdate(key = info.executorId, op = new ExecutorSummary) - - executorSummaryMap.get(info.executorId).foreach { y => - // first update failed-task, succeed-task - taskEnd.reason match { - case Success => - y.succeededTasks += 1 - case _ => - y.failedTasks += 1 - } + for (accumulableInfo <- info.accumulables) { + stageData.accumulables(accumulableInfo.id) = accumulableInfo + } - // update duration - y.taskTime += info.duration + val execSummaryMap = stageData.executorSummary + val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) - val metrics = taskEnd.taskMetrics - if (metrics != null) { - metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } - metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } - metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } - y.memoryBytesSpilled += metrics.memoryBytesSpilled - y.diskBytesSpilled += metrics.diskBytesSpilled - } + taskEnd.reason match { + case Success => + execSummary.succeededTasks += 1 + case _ => + execSummary.failedTasks += 1 } - + execSummary.taskTime += info.duration stageData.numActiveTasks -= 1 val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = @@ -171,28 +163,75 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { (Some(e.toErrorString), None) } + if (!metrics.isEmpty) { + val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) + updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics) + } - val taskRunTime = metrics.map(_.executorRunTime).getOrElse(0L) - stageData.executorRunTime += taskRunTime - val inputBytes = metrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L) - stageData.inputBytes += inputBytes - - val shuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L) - stageData.shuffleReadBytes += shuffleRead - - val shuffleWrite = - 
metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L) - stageData.shuffleWriteBytes += shuffleWrite - - val memoryBytesSpilled = metrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageData.memoryBytesSpilled += memoryBytesSpilled + val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) + taskData.taskInfo = info + taskData.taskMetrics = metrics + taskData.errorMessage = errorMessage + } + } - val diskBytesSpilled = metrics.map(_.diskBytesSpilled).getOrElse(0L) - stageData.diskBytesSpilled += diskBytesSpilled + /** + * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage + * aggregate metrics by calculating deltas between the currently recorded metrics and the new + * metrics. + */ + def updateAggregateMetrics( + stageData: StageUIData, + execId: String, + taskMetrics: TaskMetrics, + oldMetrics: Option[TaskMetrics]) { + val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) + + val shuffleWriteDelta = + (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L)) + stageData.shuffleWriteBytes += shuffleWriteDelta + execSummary.shuffleWrite += shuffleWriteDelta + + val shuffleReadDelta = + (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) + stageData.shuffleReadBytes += shuffleReadDelta + execSummary.shuffleRead += shuffleReadDelta + + val diskSpillDelta = + taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) + stageData.diskBytesSpilled += diskSpillDelta + execSummary.diskBytesSpilled += diskSpillDelta + + val memorySpillDelta = + taskMetrics.memoryBytesSpilled - oldMetrics.map(_.memoryBytesSpilled).getOrElse(0L) + stageData.memoryBytesSpilled += memorySpillDelta + execSummary.memoryBytesSpilled += memorySpillDelta + + val timeDelta = + taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) + stageData.executorRunTime += timeDelta + } - stageData.taskData(info.taskId) = new TaskUIData(info, metrics, errorMessage) + override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { + for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + val stageData = stageIdToData.getOrElseUpdate(sid, { + logWarning("Metrics update for task in unknown stage " + sid) + new StageUIData + }) + val taskData = stageData.taskData.get(taskId) + taskData.map { t => + if (!t.taskInfo.finished) { + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics, + t.taskMetrics) + + // Overwrite task metrics + t.taskMetrics = Some(taskMetrics) + } + } } - } // end of onTaskEnd + } override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index 3308c8c8a3d37..8a01ec80c9dd6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -41,7 +41,7 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) def handleKillRequest(request: HttpServletRequest) = { - if (killEnabled) { + if ((killEnabled) && 
(parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cab26b9e2f7d3..8bc1ba758cf77 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,11 +20,12 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.scheduler.AccumulableInfo /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { @@ -51,6 +52,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) + val accumulables = listener.stageIdToData(stageId).accumulables val hasInput = stageData.inputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 @@ -95,10 +97,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { // scalastyle:on + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") + def accumulableRow(acc: AccumulableInfo) = {acc.name}{acc.value} + val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, + accumulables.values.toSeq) + val taskHeaders: Seq[String] = Seq( "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", - "Launch Time", "Duration", "GC Time") ++ + "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ @@ -208,11 +215,16 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } val executorTable = new ExecutorTable(stageId, parent) + + val maybeAccumulableTable: Seq[Node] = + if (accumulables.size > 0) {

    Accumulators

    ++ accumulableTable } else Seq() + val content = summary ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++

    Aggregated Metrics by Executor

    ++ executorTable.toNodeSeq ++ + maybeAccumulableTable ++

    Tasks

    ++ taskTable UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), @@ -279,6 +291,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} + + {Unparsed( + info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("
    ") + )} + - - Browser - Standalone Cluster Master - 8080 - Web UI - master.ui.port - Jetty-based - - - Browser - Driver - 4040 - Web UI - spark.ui.port - Jetty-based - - - Browser - History Server - 18080 - Web UI - spark.history.ui.port - Jetty-based - - - Browser - Worker - 8081 - Web UI - worker.ui.port - Jetty-based - - - - Application - Standalone Cluster Master - 7077 - Submit job to cluster - spark.driver.port - Akka-based. Set to "0" to choose a port randomly - - - Worker - Standalone Cluster Master - 7077 - Join cluster - spark.driver.port - Akka-based. Set to "0" to choose a port randomly - - - Application - Worker - (random) - Join cluster - SPARK_WORKER_PORT (standalone cluster) - Akka-based - - - - - Driver and other Workers - Worker - (random) - -
      -
    • File server for file and jars
    • -
    • Http Broadcast
    • -
    • Class file server (Spark Shell only)
    • -
    - - None - Jetty-based. Each of these services starts on a random port that cannot be configured - - - +Spark makes heavy use of the network, and some environments have strict requirements for using +tight firewall settings. For a complete list of ports to configure, see the +[security page](security.html#configuring-ports-for-network-security). # High Availability By default, standalone scheduling clusters are resilient to Worker failures (insofar as Spark itself is resilient to losing work by moving it to other workers). However, the scheduler uses a Master to make scheduling decisions, and this (by default) creates a single point of failure: if the Master crashes, no new applications can be created. In order to circumvent this, we have two high availability schemes, detailed below. -## Standby Masters with ZooKeeper +# Standby Masters with ZooKeeper **Overview** @@ -432,7 +347,7 @@ There's an important distinction to be made between "registering with a Master" Due to this property, new Masters can be created at any time, and the only thing you need to worry about is that _new_ applications and Workers can find it to register with in case it becomes the leader. Once registered, you're taken care of. -## Single-Node Recovery with Local File System +# Single-Node Recovery with Local File System **Overview** diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 38728534a46e0..cd6543945c385 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -136,13 +136,13 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext.createSchemaRDD // Define the schema using a case class. -// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, +// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, // you can use custom classes that implement the Product interface. case class Person(name: String, age: Int) // Create an RDD of Person objects and register it as a table. val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -210,7 +210,7 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m // Apply a schema to an RDD of JavaBeans and register it as a table. JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); -schemaPeople.registerAsTable("people"); +schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -248,7 +248,7 @@ people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) # In future versions of PySpark we would like to add support for registering RDDs with other # datatypes as tables schemaPeople = sqlContext.inferSchema(people) -schemaPeople.registerAsTable("people") +schemaPeople.registerTempTable("people") # SQL can be run over SchemaRDDs that have been registered as a table. 
teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -292,7 +292,7 @@ people.saveAsParquetFile("people.parquet") val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile") +parquetFile.registerTempTable("parquetFile") val teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -314,7 +314,7 @@ schemaPeople.saveAsParquetFile("people.parquet"); JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.map(new Function() { public String call(Row row) { @@ -340,7 +340,7 @@ schemaPeople.saveAsParquetFile("people.parquet") parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): @@ -378,7 +378,7 @@ people.printSchema() // |-- name: StringType // Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -416,7 +416,7 @@ people.printSchema(); // |-- name: StringType // Register this JavaSchemaRDD as a table. -people.registerAsTable("people"); +people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -455,7 +455,7 @@ people.printSchema() # |-- name: StringType # Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") # SQL statements can be run by using the sql methods provided by sqlContext. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -487,19 +487,19 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can also experiment with the `LocalHiveContext`, -which is similar to `HiveContext`, but creates a local copy of the `metastore` and `warehouse` -automatically. +not have an existing Hive deployment can still create a HiveContext. When not configured by the +hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current +directory. {% highlight scala %} // sc is an existing SparkContext. 
val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -hiveContext.hql("FROM src SELECT key, value").collect().foreach(println) +hiveContext.sql("FROM src SELECT key, value").collect().foreach(println) {% endhighlight %} @@ -515,11 +515,11 @@ expressed in HiveQL. // sc is an existing JavaSparkContext. JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL. -Row[] results = hiveContext.hql("FROM src SELECT key, value").collect(); +Row[] results = hiveContext.sql("FROM src SELECT key, value").collect(); {% endhighlight %} @@ -537,18 +537,17 @@ expressed in HiveQL. from pyspark.sql import HiveContext hiveContext = HiveContext(sc) -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = hiveContext.hql("FROM src SELECT key, value").collect() +results = hiveContext.sql("FROM src SELECT key, value").collect() {% endhighlight %} - # Writing Language-Integrated Relational Queries **Language-Integrated queries are currently only supported in Scala.** @@ -573,4 +572,210 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres evaluated by the SQL execution engine. A full list of the functions supported can be found in the [ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). - \ No newline at end of file + + +## Running the Thrift JDBC server + +The Thrift JDBC server implemented here corresponds to the [`HiveServer2`] +(https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test +the JDBC server with the beeline script comes with either Spark or Hive 0.12. In order to use Hive +you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use `-Phive-thriftserver` +for maven). + +To start the JDBC server, run the following in the Spark directory: + + ./sbin/start-thriftserver.sh + +The default port the server listens on is 10000. To listen on customized host and port, please set +the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may +run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. Now you can +use beeline to test the Thrift JDBC server: + + ./bin/beeline + +Connect to the JDBC server in beeline with: + + beeline> !connect jdbc:hive2://localhost:10000 + +Beeline will ask you for a username and password. In non-secure mode, simply enter the username on +your machine and a blank password. 
For secure mode, please follow the instructions given in the +[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients) + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. + +You may also use the beeline script that comes with Hive. + +### Migration Guide for Shark Users + +#### Reducer number + +In Shark, the default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark +SQL deprecates this property in favor of a new property `spark.sql.shuffle.partitions`, whose default value +is 200. Users may customize this property via `SET`: + +``` +SET spark.sql.shuffle.partitions=10; +SELECT page, count(*) c FROM logs_last_month_cached +GROUP BY page ORDER BY c DESC LIMIT 10; +``` + +You may also put this property in `hive-site.xml` to override the default value. + +For now, the `mapred.reduce.tasks` property is still recognized, and is converted to +`spark.sql.shuffle.partitions` automatically. + +#### Caching + +The `shark.cache` table property no longer exists, and tables whose names end with `_cached` are no +longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to +let users control table caching explicitly: + +``` +CACHE TABLE logs_last_month; +UNCACHE TABLE logs_last_month; +``` + +**NOTE** `CACHE TABLE tbl` is lazy: it only marks table `tbl` as "needs to be cached if necessary", +but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be +cached, you may simply count the table immediately after executing `CACHE TABLE`: + +``` +CACHE TABLE logs_last_month; +SELECT COUNT(1) FROM logs_last_month; +``` + +Several caching-related features are not supported yet: + +* User defined partition level cache eviction policy +* RDD reloading +* In-memory cache write through policy + +### Compatibility with Apache Hive + +#### Deploying in Existing Hive Warehouses + +The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive +installations. You do not need to modify your existing Hive Metastore or change the data placement +or partitioning of your tables. + +#### Supported Hive Features + +Spark SQL supports the vast majority of Hive features, such as: + +* Hive query statements, including: + * `SELECT` + * `GROUP BY` + * `ORDER BY` + * `CLUSTER BY` + * `SORT BY` +* All Hive operators, including: + * Relational operators (`=`, `<=>`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc) + * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc) + * Logical operators (`AND`, `&&`, `OR`, `||`, etc) + * Complex type constructors + * Mathematical functions (`sign`, `ln`, `cos`, etc) + * String functions (`instr`, `length`, `printf`, etc) +* User defined functions (UDF) +* User defined aggregation functions (UDAF) +* User defined serialization formats (SerDe's) +* Joins + * `JOIN` + * `{LEFT|RIGHT|FULL} OUTER JOIN` + * `LEFT SEMI JOIN` + * `CROSS JOIN` +* Unions +* Sub queries + * `SELECT col FROM ( SELECT a + b AS col from t1) t2` +* Sampling +* Explain +* Partitioned tables +* All Hive DDL Functions, including: + * `CREATE TABLE` + * `CREATE TABLE AS SELECT` + * `ALTER TABLE` +* Most Hive Data types, including: + * `TINYINT` + * `SMALLINT` + * `INT` + * `BIGINT` + * `BOOLEAN` + * `FLOAT` + * `DOUBLE` + * `STRING` + * `BINARY` + * `TIMESTAMP` + * `ARRAY<>` + * `MAP<>` + * `STRUCT<>` + +#### Unsupported Hive Functionality + +Below is a list of Hive features that we don't support yet.
Most of these features are rarely used +in Hive deployments. + +**Major Hive Features** + +* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL + doesn't support buckets yet. + +**Esoteric Hive Features** + +* Tables with partitions using different input formats: In Spark SQL, all table partitions need to + have the same input format. +* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions + (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. +* `UNIONTYPE` +* Unique join +* Single query multi insert +* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at + the moment. + +**Hive Input/Output Formats** + +* File format for CLI: For results showing back to the CLI, Spark SQL only supports TextOutputFormat. +* Hadoop archive + +**Hive Optimizations** + +A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are +not necessary due to Spark SQL's in-memory computational model. Others are slotted for future +releases of Spark SQL. + +* Block level bitmap indexes and virtual columns (used to build indexes) +* Automatically convert a join to map join: For joining a large table with multiple small tables, + Hive automatically converts the join into a map join. We are adding this auto conversion in the + next release. +* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you + need to control the degree of parallelism post-shuffle using "SET + spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the + next release. +* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still + launches tasks to compute the result. +* Skew data flag: Spark SQL does not follow the skew data flags in Hive. +* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. +* Merge multiple small files for query results: if the result output contains multiple small files, + Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS + metadata. Spark SQL does not support that. + +## Running the Spark SQL CLI + +The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute +queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. + +To start the Spark SQL CLI, run the following in the Spark directory: + + ./bin/spark-sql + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +You may run `./bin/spark-sql --help` for a complete list of all available +options. + +# Cached tables + +Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. +Then Spark SQL will scan only required columns and will automatically tune compression to minimize +memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. + +Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in +in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to +cache tables. 
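As a rough illustration of the caching behaviour described above, here is a minimal Scala sketch; it assumes an existing `SQLContext` named `sqlContext` and a table already registered as `people` via `registerTempTable`, as in the earlier examples in this guide.

```scala
// Cache the table in Spark SQL's in-memory columnar format.
sqlContext.cacheTable("people")

// Subsequent queries read only the required columns from the columnar cache.
val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
teenagers.collect().foreach(println)

// Remove the table from memory once it is no longer needed.
sqlContext.uncacheTable("people")
```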
diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a2dc3a8961dfc..1e045a3dd0ca9 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, files, sockets, etc.). +the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. @@ -174,7 +174,7 @@ val words = lines.flatMap(_.split(" ")) ... {% endhighlight %} -The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/streaming/examples/CustomReceiver.scala). +The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
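For readers who do not want to open the linked example right away, the overall shape of a custom receiver looks roughly like the following sketch. Error handling and reconnection logic are omitted for brevity; the linked `CustomReceiver.scala` remains the authoritative version.

```scala
import java.net.Socket
import scala.io.Source

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receiver.Receiver

// Minimal sketch of a receiver that reads newline-delimited text from a socket.
class SimpleSocketReceiver(host: String, port: Int)
  extends Receiver[String](StorageLevel.MEMORY_AND_DISK_2) {

  def onStart() {
    // Receive data on a separate thread so that onStart() returns immediately.
    new Thread("Simple Socket Receiver") {
      override def run() {
        val socket = new Socket(host, port)
        Source.fromInputStream(socket.getInputStream).getLines().foreach(line => store(line))
        restart("Socket closed, restarting receiver")
      }
    }.start()
  }

  def onStop() {
    // The receiving thread stops itself once the socket is closed or restart() is called.
  }
}

// Used from a StreamingContext, e.g.:
//   val lines = ssc.receiverStream(new SimpleSocketReceiver("localhost", 9999))
```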
    diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md new file mode 100644 index 0000000000000..801c905c88df8 --- /dev/null +++ b/docs/streaming-kinesis.md @@ -0,0 +1,58 @@ +--- +layout: global +title: Spark Streaming Kinesis Receiver +--- + +### Kinesis +Build notes: +
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • +
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • +
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • +
• To build with Kinesis, you must run the maven or sbt builds with `-Pkinesis-asl`.
  • +
• Applications will need to link to the `spark-streaming-kinesis-asl` artifact, for example via an sbt dependency as sketched below.
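A minimal sbt sketch of that dependency (the group ID and the version shown are assumptions; match them to the Spark build you are actually running):

```scala
// build.sbt (sketch) -- link against the Kinesis ASL artifact.
libraryDependencies += "org.apache.spark" %% "spark-streaming-kinesis-asl" % "1.1.0"
```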
  • + +Kinesis examples notes: +
• To build the Kinesis examples, you must run the maven or sbt builds with `-Pkinesis-asl`.
  • +
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • +
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • +
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • + +Deployment and runtime notes: +
  • A single KinesisReceiver can process many shards of a stream.
  • +
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream.
  • +
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • +
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    + 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    + 2) Java System Properties - aws.accessKeyId and aws.secretKey
    + 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    + 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    +
  • +
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    + http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, +retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • +
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). +Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, +it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
  • + +Failure recovery notes: +
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    + 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    + 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    + 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
    +
  • +
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • +
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) +or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • +
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data +depending on the checkpoint frequency.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
  • +
  • Record processing should be idempotent when possible.
  • +
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • +
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
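The deployment notes above translate into roughly the following Scala sketch. It is not an authoritative example (see the bundled KinesisWordCountASL for the real thing); the stream name, endpoint URL, shard count, and the exact `KinesisUtils.createStream` parameter list are assumptions that should be checked against the `kinesis-asl` artifact:

```scala
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

object KinesisSketch {
  def main(args: Array[String]) {
    val numShards = 2                          // assumption: the stream has 2 shards
    val batchInterval = Seconds(2)

    val conf = new SparkConf().setAppName("KinesisSketch")
    val ssc = new StreamingContext(conf, batchInterval)

    // One DStream/receiver per shard (never more receivers than shards).
    val streams = (0 until numShards).map { _ =>
      KinesisUtils.createStream(ssc, "myKinesisStream",
        "https://kinesis.us-east-1.amazonaws.com", batchInterval,
        InitialPositionInStream.TRIM_HORIZON, StorageLevel.MEMORY_AND_DISK_2)
    }

    // Union the per-shard streams and count records per batch.
    val unioned = ssc.union(streams)
    unioned.count().map(c => s"Received $c Kinesis records.").print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```

Per the notes above, TRIM_HORIZON is the safer default for production (at the cost of possible duplicate processing), while LATEST only picks up data that arrives after the receiver starts.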
  • diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 90a0eef60c200..9f331ed50d2a4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -9,7 +9,7 @@ title: Spark Streaming Programming Guide # Overview Spark Streaming is an extension of the core Spark API that allows enables high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ or plain old TCP sockets and be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis or plain old TCP sockets and be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's in-built @@ -38,7 +38,7 @@ stream of results in batches. Spark Streaming provides a high-level abstraction called *discretized stream* or *DStream*, which represents a continuous stream of data. DStreams can be created either from input data -stream from sources such as Kafka and Flume, or by applying high-level +stream from sources such as Kafka, Flume, and Kinesis, or by applying high-level operations on other DStreams. Internally, a DStream is represented as a sequence of [RDDs](api/scala/index.html#org.apache.spark.rdd.RDD). @@ -313,7 +313,7 @@ To write your own Spark Streaming program, you will have to add the following de artifactId = spark-streaming_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION}} -For ingesting data from sources like Kafka and Flume that are not present in the Spark +For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark Streaming core API, you will have to add the corresponding artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, @@ -327,6 +327,7 @@ some of the common ones are as follows. Twitter spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}} ZeroMQ spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}} MQTT spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}} + Kinesis
    (built separately) kinesis-asl_{{site.SCALA_BINARY_VERSION}} @@ -442,7 +443,7 @@ see the API documentations of the relevant functions in Scala and [JavaStreamingContext](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) for Java. -Additional functionality for creating DStreams from sources such as Kafka, Flume, and Twitter +Additional functionality for creating DStreams from sources such as Kafka, Flume, Kinesis, and Twitter can be imported by adding the right dependencies as explained in an [earlier](#linking) section. To take the case of Kafka, after adding the artifact `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` to the @@ -467,6 +468,9 @@ For more details on these additional sources, see the corresponding [API documen Furthermore, you can also implement your own custom receiver for your sources. See the [Custom Receiver Guide](streaming-custom-receivers.html). +### Kinesis +[Kinesis](streaming-kinesis.html) + ## Operations There are two kinds of DStream operations - _transformations_ and _output operations_. Similar to RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams @@ -939,7 +943,7 @@ Receiving multiple data streams can therefore be achieved by creating multiple i and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input stream receiving two topics of data can be split into two Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to received in parallel, and increasing overall throughput. +thus allowing data to be received in parallel, and increasing overall throughput. Another parameter that should be considered is the receiver's blocking interval. For most receivers, the received data is coalesced together into large blocks of data before storing inside Spark's memory. @@ -980,7 +984,7 @@ If the number of tasks launched per second is high (say, 50 or more per second), of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduced the task +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to diff --git a/docs/tuning.md b/docs/tuning.md index 4917c11bc1147..8fb2a0433b1a8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -32,7 +32,7 @@ in your operations) and performance. It provides two serialization libraries: [`java.io.Externalizable`](http://docs.oracle.com/javase/6/docs/api/java/io/Externalizable.html). Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. -* [Kryo serialization](http://code.google.com/p/kryo/): Spark can also use +* [Kryo serialization](https://github.com/EsotericSoftware/kryo): Spark can also use the Kryo library (version 2) to serialize objects more quickly. 
Kryo is significantly faster and more compact than Java serialization (often as much as 10x), but does not support all `Serializable` types and requires you to *register* the classes you'll use in the program in advance @@ -68,7 +68,7 @@ conf.set("spark.kryo.registrator", "mypackage.MyRegistrator") val sc = new SparkContext(conf) {% endhighlight %} -The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced +The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 02cfe4ec39c7d..0c2f85a3868f4 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -135,6 +135,10 @@ def parse_args(): "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + "(e.g -Dspark.worker.timeout=180)") + parser.add_option( + "--user-data", type="string", default="", + help="Path to a user-data file (most AMI's interpret this as an initialization script)") + (opts, args) = parser.parse_args() if len(args) != 2: @@ -274,6 +278,12 @@ def launch_cluster(conn, opts, cluster_name): if opts.key_pair is None: print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." sys.exit(1) + + user_data_content = None + if opts.user_data: + with open(opts.user_data) as user_data_file: + user_data_content = user_data_file.read() + print "Setting up security groups..." master_group = get_or_make_group(conn, cluster_name + "-master") slave_group = get_or_make_group(conn, cluster_name + "-slaves") @@ -347,7 +357,8 @@ def launch_cluster(conn, opts, cluster_name): key_name=opts.key_pair, security_groups=[slave_group], instance_type=opts.instance_type, - block_device_map=block_map) + block_device_map=block_map, + user_data=user_data_content) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -398,7 +409,8 @@ def launch_cluster(conn, opts, cluster_name): placement=zone, min_count=num_slaves_this_zone, max_count=num_slaves_this_zone, - block_device_map=block_map) + block_device_map=block_map, + user_data=user_data_content) slave_nodes += slave_res.instances print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, zone, slave_res.id) diff --git a/examples/pom.xml b/examples/pom.xml index bd1c387c2eb91..8c4c128bb484d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -28,12 +28,25 @@ org.apache.spark spark-examples_2.10 - examples + examples jar Spark Project Examples http://spark.apache.org/ + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + + + org.apache.spark diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 607df3eddd550..898297dc658ba 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -74,7 +74,7 @@ public Person call(String line) throws Exception { // Apply a schema to an RDD of Java Beans and register it as a table. JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); - schemaPeople.registerAsTable("people"); + schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. 
JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -100,7 +100,7 @@ public String call(Row row) { JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. - parquetFile.registerAsTable("parquetFile"); + parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.map(new Function() { @@ -128,7 +128,7 @@ public String call(Row row) { // |-- name: StringType // Register this JavaSchemaRDD as a table. - peopleFromJsonFile.registerAsTable("people"); + peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -158,7 +158,7 @@ public String call(Row row) { // | |-- state: StringType // |-- name: StringType - peopleFromJsonRDD.registerAsTable("people2"); + peopleFromJsonRDD.registerTempTable("people2"); JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.map(new Function() { diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py new file mode 100644 index 0000000000000..1dfbf98604425 --- /dev/null +++ b/examples/src/main/python/cassandra_outputformat.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Create data in Cassandra fist +(following: https://wiki.apache.org/cassandra/GettingStarted) + +cqlsh> CREATE KEYSPACE test + ... WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; +cqlsh> use test; +cqlsh:test> CREATE TABLE users ( + ... user_id int PRIMARY KEY, + ... fname text, + ... lname text + ... ); + +> cassandra_outputformat test users 1745 john smith +> cassandra_outputformat test users 1744 john doe +> cassandra_outputformat test users 1746 john smith + +cqlsh:test> SELECT * FROM users; + + user_id | fname | lname +---------+-------+------- + 1745 | john | smith + 1744 | john | doe + 1746 | john | smith +""" +if __name__ == "__main__": + if len(sys.argv) != 7: + print >> sys.stderr, """ + Usage: cassandra_outputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_outputformat.py + Assumes you have created the following table in Cassandra already, + running on , in . + + cqlsh:> CREATE TABLE ( + ... user_id int PRIMARY KEY, + ... fname text, + ... lname text + ... 
); + """ + exit(-1) + + host = sys.argv[1] + keyspace = sys.argv[2] + cf = sys.argv[3] + sc = SparkContext(appName="CassandraOutputFormat") + + conf = {"cassandra.output.thrift.address":host, + "cassandra.output.thrift.port":"9160", + "cassandra.output.keyspace":keyspace, + "cassandra.output.partitioner.class":"Murmur3Partitioner", + "cassandra.output.cql":"UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?", + "mapreduce.output.basename":cf, + "mapreduce.outputformat.class":"org.apache.cassandra.hadoop.cql3.CqlOutputFormat", + "mapreduce.job.output.key.class":"java.util.Map", + "mapreduce.job.output.value.class":"java.util.List"} + key = {"user_id" : int(sys.argv[4])} + sc.parallelize([(key, sys.argv[5:])]).saveAsNewAPIHadoopDataset( + conf=conf, + keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", + valueConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLValueConverter") diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 3289d9880a0f5..c9fa8e171c2a1 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -65,7 +65,8 @@ "org.apache.hadoop.hbase.mapreduce.TableInputFormat", "org.apache.hadoop.hbase.io.ImmutableBytesWritable", "org.apache.hadoop.hbase.client.Result", - valueConverter="org.apache.spark.examples.pythonconverters.HBaseConverter", + keyConverter="org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter", + valueConverter="org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter", conf=conf) output = hbase_rdd.collect() for (k, v) in output: diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py new file mode 100644 index 0000000000000..5e11548fd13f7 --- /dev/null +++ b/examples/src/main/python/hbase_outputformat.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys + +from pyspark import SparkContext + +""" +Create test table in HBase first: + +hbase(main):001:0> create 'test', 'f1' +0 row(s) in 0.7840 seconds + +> hbase_outputformat test row1 f1 q1 value1 +> hbase_outputformat test row2 f1 q1 value2 +> hbase_outputformat test row3 f1 q1 value3 +> hbase_outputformat test row4 f1 q1 value4 + +hbase(main):002:0> scan 'test' +ROW COLUMN+CELL + row1 column=f1:q1, timestamp=1405659615726, value=value1 + row2 column=f1:q1, timestamp=1405659626803, value=value2 + row3 column=f1:q1, timestamp=1405659640106, value=value3 + row4 column=f1:q1, timestamp=1405659650292, value=value4 +4 row(s) in 0.0780 seconds +""" +if __name__ == "__main__": + if len(sys.argv) != 7: + print >> sys.stderr, """ + Usage: hbase_outputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_outputformat.py + Assumes you have created
    with column family in HBase running on already + """ + exit(-1) + + host = sys.argv[1] + table = sys.argv[2] + sc = SparkContext(appName="HBaseOutputFormat") + + conf = {"hbase.zookeeper.quorum": host, + "hbase.mapred.outputtable": table, + "mapreduce.outputformat.class" : "org.apache.hadoop.hbase.mapreduce.TableOutputFormat", + "mapreduce.job.output.key.class" : "org.apache.hadoop.hbase.io.ImmutableBytesWritable", + "mapreduce.job.output.value.class" : "org.apache.hadoop.io.Writable"} + + sc.parallelize([sys.argv[3:]]).map(lambda x: (x[0], x)).saveAsNewAPIHadoopDataset( + conf=conf, + keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", + valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py new file mode 100755 index 0000000000000..8efadb5223f56 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision tree classification and regression using MLlib. +""" + +import numpy, os, sys + +from operator import add + +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + + +def getAccuracy(dtModel, data): + """ + Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + (x[0] == x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainCorrect / (0.0 + data.count()) + + +def getMSE(dtModel, data): + """ + Return mean squared error (MSE) of DecisionTreeModel on the given + RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainMSE / (0.0 + data.count()) + + +def reindexClassLabels(data): + """ + Re-index class labels in a dataset to the range {0,...,numClasses-1}. + If all labels in that range already appear at least once, + then the returned RDD is the same one (without a mapping). + Note: If a label simply does not appear in the data, + the index will not include it. + Be aware of this when reindexing subsampled data. + :param data: RDD of LabeledPoint where labels are integer values + denoting labels for a classification problem. 
+ :return: Pair (reindexedData, origToNewLabels) where + reindexedData is an RDD of LabeledPoint with labels in + the range {0,...,numClasses-1}, and + origToNewLabels is a dictionary mapping original labels + to new labels. + """ + # classCounts: class --> # examples in class + classCounts = data.map(lambda x: x.label).countByValue() + numExamples = sum(classCounts.values()) + sortedClasses = sorted(classCounts.keys()) + numClasses = len(classCounts) + # origToNewLabels: class --> index in 0,...,numClasses-1 + if (numClasses < 2): + print >> sys.stderr, \ + "Dataset for classification should have at least 2 classes." + \ + " The given dataset had only %d classes." % numClasses + exit(1) + origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) + + print "numClasses = %d" % numClasses + print "Per-class example fractions, counts:" + print "Class\tFrac\tCount" + for c in sortedClasses: + frac = classCounts[c] / (numExamples + 0.0) + print "%g\t%g\t%d" % (c, frac, classCounts[c]) + + if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): + return (data, origToNewLabels) + else: + reindexedData = \ + data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) + return (reindexedData, origToNewLabels) + + +def usage(): + print >> sys.stderr, \ + "Usage: decision_tree_runner [libsvm format data filepath]\n" + \ + " Note: This only supports binary classification." + exit(1) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + usage() + sc = SparkContext(appName="PythonDT") + + # Load data. + dataPath = 'data/mllib/sample_libsvm_data.txt' + if len(sys.argv) == 2: + dataPath = sys.argv[1] + if not os.path.isfile(dataPath): + usage() + points = MLUtils.loadLibSVMFile(sc, dataPath) + + # Re-index class labels if needed. + (reindexedData, origToNewLabels) = reindexClassLabels(points) + + # Train a classifier. + model = DecisionTree.trainClassifier(reindexedData, numClasses=2) + # Print learned tree and stats. + print "Trained DecisionTree for classification:" + print " Model numNodes: %d\n" % model.numNodes() + print " Model depth: %d\n" % model.depth() + print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) + print model diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 6e0f7a4ee5a81..9d547ff77c984 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -30,8 +30,10 @@ from pyspark.mllib.classification import LogisticRegressionWithSGD -# Parse a line of text into an MLlib LabeledPoint object def parsePoint(line): + """ + Parse a line of text into an MLlib LabeledPoint object. 
+ """ values = [float(s) for s in line.split(' ')] if values[0] == -1: # Convert -1 labels to 0 for MLlib values[0] = 0 diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 43f13fe24f0d0..cf3d2cca81ff6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,7 +21,6 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} @@ -33,9 +32,12 @@ import org.apache.spark.rdd.RDD /** * An example runner for decision tree. Run with * {{{ - * ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] + * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. */ object DecisionTreeRunner { @@ -48,11 +50,12 @@ object DecisionTreeRunner { case class Params( input: String = null, + dataFormat: String = "libsvm", algo: Algo = Classification, - numClassesForClassification: Int = 2, - maxDepth: Int = 5, + maxDepth: Int = 4, impurity: ImpurityType = Gini, - maxBins: Int = 100) + maxBins: Int = 100, + fracTest: Double = 0.2) def main(args: Array[String]) { val defaultParams = Params() @@ -69,25 +72,31 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("numClassesForClassification") - .text(s"number of classes for classification, " - + s"default: ${defaultParams.numClassesForClassification}") - .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) arg[String]("") .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)") .required() .action((x, c) => c.copy(input = x)) checkConfig { params => - if (params.algo == Classification && - (params.impurity == Gini || params.impurity == Entropy)) { - success - } else if (params.algo == Regression && params.impurity == Variance) { - success + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") } else { - failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + if (params.algo == Classification && + (params.impurity == Gini || params.impurity == Entropy)) { + success + } else if (params.algo == Regression && params.impurity == Variance) { + success + } else { + failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + } } } } @@ -100,16 +109,57 @@ object 
DecisionTreeRunner { } def run(params: Params) { + val conf = new SparkConf().setAppName("DecisionTreeRunner") val sc = new SparkContext(conf) // Load training data and cache it. - val examples = MLUtils.loadLabeledPoints(sc, params.input).cache() + val origExamples = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + } + // For classification, re-index classes if needed. + val (examples, numClasses) = params.algo match { + case Classification => { + // classCounts: class --> # examples in class + val classCounts = origExamples.map(_.label).countByValue() + val sortedClasses = classCounts.keys.toList.sorted + val numClasses = classCounts.size + // classIndexMap: class --> index in 0,...,numClasses-1 + val classIndexMap = { + if (classCounts.keySet != Set(0.0, 1.0)) { + sortedClasses.zipWithIndex.toMap + } else { + Map[Double, Int]() + } + } + val examples = { + if (classIndexMap.isEmpty) { + origExamples + } else { + origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features)) + } + } + val numExamples = examples.count() + println(s"numClasses = $numClasses.") + println(s"Per-class example fractions, counts:") + println(s"Class\tFrac\tCount") + sortedClasses.foreach { c => + val frac = classCounts(c) / numExamples.toDouble + println(s"$c\t$frac\t${classCounts(c)}") + } + (examples, numClasses) + } + case Regression => + (origExamples, 0) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } - val splits = examples.randomSplit(Array(0.8, 0.2)) + // Split into training, test. + val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) val training = splits(0).cache() val test = splits(1).cache() - val numTraining = training.count() val numTest = test.count() @@ -129,17 +179,19 @@ object DecisionTreeRunner { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = params.numClassesForClassification) + numClassesForClassification = numClasses) val model = DecisionTree.train(training, strategy) + println(model) + if (params.algo == Classification) { val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy.") + println(s"Test accuracy = $accuracy") } if (params.algo == Regression) { val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse.") + println(s"Test mean squared error = $mse") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 4811bb70e4b28..05b7d66f8dffd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -91,7 +91,7 @@ object LinearRegression extends App { Logger.getRootLogger.setLevel(Level.WARN) - val examples = MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache() + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() val splits = examples.randomSplit(Array(0.8, 0.2)) val training = splits(0).cache() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 537e68a0991aa..88acd9dbb0878 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -22,7 +22,7 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.classification.NaiveBayes -import org.apache.spark.mllib.util.{MLUtils, MulticlassLabelParser} +import org.apache.spark.mllib.util.MLUtils /** * An example naive Bayes app. Run with @@ -76,7 +76,7 @@ object SparseNaiveBayes { if (params.minPartitions > 0) params.minPartitions else sc.defaultMinPartitions val examples = - MLUtils.loadLibSVMFile(sc, params.input, multiclass = true, params.numFeatures, minPartitions) + MLUtils.loadLibSVMFile(sc, params.input, params.numFeatures, minPartitions) // Cache examples because it will be used in both training and evaluation. examples.cache() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala new file mode 100644 index 0000000000000..1fd37edfa7427 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Train a linear regression model on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the text files must be labeled data points in the form + * `(y,[x1,x2,x3,...,xn])` + * Where n is the number of features. n must be the same for train and test. + * + * Usage: StreamingLinearRegression + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, and 2 features per data point, call: + * $ bin/run-example \ + * org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2 + * + * As you add text files to `trainingDir` the model will continuously update. + * Anytime you add text files to `testDir`, you'll see predictions from the current model. 
+ * + */ +object StreamingLinearRegression { + + def main(args: Array[String]) { + + if (args.length != 4) { + System.err.println( + "Usage: StreamingLinearRegression ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0)) + val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1)) + + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) + + model.trainOn(trainingData) + model.predictOn(testData).print() + + ssc.start() + ssc.awaitTermination() + + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 29a65c7a5f295..83feb5703b908 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples.pythonconverters import org.apache.spark.api.python.Converter import java.nio.ByteBuffer import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions.{mapAsJavaMap, mapAsScalaMap} +import collection.JavaConversions._ /** @@ -44,3 +44,25 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) } } + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * Map[String, Int] to Cassandra key + */ +class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { + override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { + val input = obj.asInstanceOf[java.util.Map[String, Int]] + mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * List[String] to Cassandra value + */ +class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { + override def convert(obj: Any): java.util.List[ByteBuffer] = { + val input = obj.asInstanceOf[java.util.List[String]] + seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala new file mode 100644 index 0000000000000..273bee0a8b30f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.pythonconverters + +import scala.collection.JavaConversions._ + +import org.apache.spark.api.python.Converter +import org.apache.hadoop.hbase.client.{Put, Result} +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.hadoop.hbase.util.Bytes + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts an + * HBase Result to a String + */ +class HBaseResultToStringConverter extends Converter[Any, String] { + override def convert(obj: Any): String = { + val result = obj.asInstanceOf[Result] + Bytes.toStringBinary(result.value()) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts an + * ImmutableBytesWritable to a String + */ +class ImmutableBytesWritableToStringConverter extends Converter[Any, String] { + override def convert(obj: Any): String = { + val key = obj.asInstanceOf[ImmutableBytesWritable] + Bytes.toStringBinary(key.get()) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * String to an ImmutableBytesWritable + */ +class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBytesWritable] { + override def convert(obj: Any): ImmutableBytesWritable = { + val bytes = Bytes.toBytes(obj.asInstanceOf[String]) + new ImmutableBytesWritable(bytes) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * list of Strings to HBase Put + */ +class StringListToPutConverter extends Converter[Any, Put] { + override def convert(obj: Any): Put = { + val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val put = new Put(output(0)) + put.add(output(1), output(2), output(3)) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 63db688bfb8c0..d56d64c564200 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -36,7 +36,7 @@ object RDDRelation { val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") @@ -66,7 +66,7 @@ object RDDRelation { parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) // These files can also be registered as tables. 
- parquetFile.registerAsTable("parquetFile") + parquetFile.registerTempTable("parquetFile") sql("SELECT * FROM parquetFile").collect().foreach(println) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 66a23fac39999..3423fac0ad303 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.sql.hive import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ -import org.apache.spark.sql.hive.LocalHiveContext +import org.apache.spark.sql.hive.HiveContext object HiveFromSpark { case class Record(key: Int, value: String) @@ -31,23 +31,23 @@ object HiveFromSpark { // A local hive context creates an instance of the Hive Metastore in process, storing the // the warehouse data in the current directory. This location can be overridden by // specifying a second parameter to the constructor. - val hiveContext = new LocalHiveContext(sc) + val hiveContext = new HiveContext(sc) import hiveContext._ - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") - hql("SELECT * FROM src").collect.foreach(println) + sql("SELECT * FROM src").collect.foreach(println) // Aggregation queries are also supported. - val count = hql("SELECT COUNT(*) FROM src").collect().head.getLong(0) + val count = sql("SELECT COUNT(*) FROM src").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") + val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") val rddAsStrings = rddFromSql.map { @@ -56,10 +56,10 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") - hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) + sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala new file mode 100644 index 0000000000000..1cc8c8d5c23b6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.flume._ +import org.apache.spark.util.IntParam +import java.net.InetSocketAddress + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with the Spark Sink running in a Flume agent. See + * the Spark Streaming programming guide for more details. + * + * Usage: FlumePollingEventCount + * `host` is the host on which the Spark Sink is running. + * `port` is the port at which the Spark Sink is listening. + * + * To run this example: + * `$ bin/run-example org.apache.spark.examples.streaming.FlumePollingEventCount [host] [port] ` + */ +object FlumePollingEventCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println( + "Usage: FlumePollingEventCount ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(host, IntParam(port)) = args + + val batchInterval = Milliseconds(2000) + + // Create the context and set the batch size + val sparkConf = new SparkConf().setAppName("FlumePollingEventCount") + val ssc = new StreamingContext(sparkConf, batchInterval) + + // Create a flume stream that polls the Spark Sink running in a Flume agent + val stream = FlumeUtils.createPollingStream(ssc, host, port) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events." 
).print() + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml new file mode 100644 index 0000000000000..d0bf1cf1ea796 --- /dev/null +++ b/external/flume-sink/pom.xml @@ -0,0 +1,103 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + spark-streaming-flume-sink_2.10 + + streaming-flume-sink + + + jar + Spark Project External Flume Sink + http://spark.apache.org/ + + + org.apache.flume + flume-ng-sdk + 1.4.0 + + + io.netty + netty + + + org.apache.thrift + libthrift + + + + + org.apache.flume + flume-ng-core + 1.4.0 + + + io.netty + netty + + + org.apache.thrift + libthrift + + + + + org.scala-lang + scala-library + + + org.scalatest + scalatest_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.avro + avro-maven-plugin + 1.7.3 + + + ${project.basedir}/target/scala-${scala.binary.version}/src_managed/main/compiled_avro + + + + generate-sources + + idl-protocol + + + + + + + diff --git a/external/flume-sink/src/main/avro/sparkflume.avdl b/external/flume-sink/src/main/avro/sparkflume.avdl new file mode 100644 index 0000000000000..8806e863ac7c6 --- /dev/null +++ b/external/flume-sink/src/main/avro/sparkflume.avdl @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +@namespace("org.apache.spark.streaming.flume.sink") + +protocol SparkFlumeProtocol { + + record SparkSinkEvent { + map headers; + bytes body; + } + + record EventBatch { + string errorMsg = ""; // If this is empty it is a valid message, else it represents an error + string sequenceNumber; + array events; + } + + EventBatch getEventBatch (int n); + + void ack (string sequenceNumber); + + void nack (string sequenceNumber); +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala new file mode 100644 index 0000000000000..17cbc6707b5ea --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import org.slf4j.{Logger, LoggerFactory} + +/** + * Copy of the org.apache.spark.Logging for being used in the Spark Sink. + * The org.apache.spark.Logging is not used so that all of Spark is not brought + * in as a dependency. + */ +private[sink] trait Logging { + // Make the log field transient so that objects with Logging can + // be serialized and used on another machine + @transient private var log_ : Logger = null + + // Method to get or create the logger for this object + protected def log: Logger = { + if (log_ == null) { + initializeIfNecessary() + var className = this.getClass.getName + // Ignore trailing $'s in the class names for Scala objects + if (className.endsWith("$")) { + className = className.substring(0, className.length - 1) + } + log_ = LoggerFactory.getLogger(className) + } + log_ + } + + // Log methods that take only a String + protected def logInfo(msg: => String) { + if (log.isInfoEnabled) log.info(msg) + } + + protected def logDebug(msg: => String) { + if (log.isDebugEnabled) log.debug(msg) + } + + protected def logTrace(msg: => String) { + if (log.isTraceEnabled) log.trace(msg) + } + + protected def logWarning(msg: => String) { + if (log.isWarnEnabled) log.warn(msg) + } + + protected def logError(msg: => String) { + if (log.isErrorEnabled) log.error(msg) + } + + // Log methods that take Throwables (Exceptions/Errors) too + protected def logInfo(msg: => String, throwable: Throwable) { + if (log.isInfoEnabled) log.info(msg, throwable) + } + + protected def logDebug(msg: => String, throwable: Throwable) { + if (log.isDebugEnabled) log.debug(msg, throwable) + } + + protected def logTrace(msg: => String, throwable: Throwable) { + if (log.isTraceEnabled) log.trace(msg, throwable) + } + + protected def logWarning(msg: => String, throwable: Throwable) { + if (log.isWarnEnabled) log.warn(msg, throwable) + } + + protected def logError(msg: => String, throwable: Throwable) { + if (log.isErrorEnabled) log.error(msg, throwable) + } + + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + + private def initializeIfNecessary() { + if (!Logging.initialized) { + Logging.initLock.synchronized { + if (!Logging.initialized) { + initializeLogging() + } + } + } + } + + private def initializeLogging() { + Logging.initialized = true + + // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html + log + } +} + +private[sink] object Logging { + @volatile private var initialized = false + val initLock = new Object() + try { + // We use reflection here to handle the case where users remove the + // slf4j-to-jul bridge order to route their logs to JUL. 
+ val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) + val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] + if (!installed) { + bridgeClass.getMethod("install").invoke(null) + } + } catch { + case e: ClassNotFoundException => // can't log anything yet so just fail silently + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala new file mode 100644 index 0000000000000..7da8eb3e35912 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.concurrent.atomic.AtomicLong + +import org.apache.flume.Channel +import org.apache.commons.lang.RandomStringUtils +import com.google.common.util.concurrent.ThreadFactoryBuilder + +/** + * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process + * requests. Each getEvents, ack and nack call is forwarded to an instance of this class. + * @param threads Number of threads to use to process requests. + * @param channel The channel that the sink pulls events from + * @param transactionTimeout Timeout in millis after which the transaction if not acked by Spark + * is rolled back. + */ +// Flume forces transactions to be thread-local. So each transaction *must* be committed, or +// rolled back from the thread it was originally created in. So each getEvents call from Spark +// creates a TransactionProcessor which runs in a new thread, in which the transaction is created +// and events are pulled off the channel. Once the events are sent to spark, +// that thread is blocked and the TransactionProcessor is saved in a map, +// until an ACK or NACK comes back or the transaction times out (after the specified timeout). +// When the response comes or a timeout is hit, the TransactionProcessor is retrieved and then +// unblocked, at which point the transaction is committed or rolled back. 
+ +private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, + val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { + val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, + new ThreadFactoryBuilder().setDaemon(true) + .setNameFormat("Spark Sink Processor Thread - %d").build())) + private val processorMap = new ConcurrentHashMap[CharSequence, TransactionProcessor]() + // This sink will not persist sequence numbers and reuses them if it gets restarted. + // So it is possible to commit a transaction which may have been meant for the sink before the + // restart. + // Since the new txn may not have the same sequence number we must guard against accidentally + // committing a new transaction. To reduce the probability of that happening a random string is + // prepended to the sequence number. Does not change for life of sink + private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqCounter = new AtomicLong(0) + + /** + * Returns a bunch of events to Spark over Avro RPC. + * @param n Maximum number of events to return in a batch + * @return [[EventBatch]] instance that has a sequence number and an array of at most n events + */ + override def getEventBatch(n: Int): EventBatch = { + logDebug("Got getEventBatch call from Spark.") + val sequenceNumber = seqBase + seqCounter.incrementAndGet() + val processor = new TransactionProcessor(channel, sequenceNumber, + n, transactionTimeout, backOffInterval, this) + transactionExecutorOpt.foreach(executor => { + executor.submit(processor) + }) + // Wait until a batch is available - will be an error if error message is non-empty + val batch = processor.getEventBatch + if (!SparkSinkUtils.isErrorBatch(batch)) { + processorMap.put(sequenceNumber.toString, processor) + logDebug("Sending event batch with sequence number: " + sequenceNumber) + } + batch + } + + /** + * Called by Spark to indicate successful commit of a batch + * @param sequenceNumber The sequence number of the event batch that was successful + */ + override def ack(sequenceNumber: CharSequence): Void = { + logDebug("Received Ack for batch with sequence number: " + sequenceNumber) + completeTransaction(sequenceNumber, success = true) + null + } + + /** + * Called by Spark to indicate failed commit of a batch + * @param sequenceNumber The sequence number of the event batch that failed + * @return + */ + override def nack(sequenceNumber: CharSequence): Void = { + completeTransaction(sequenceNumber, success = false) + logInfo("Spark failed to commit transaction. Will reattempt events.") + null + } + + /** + * Helper method to commit or rollback a transaction. + * @param sequenceNumber The sequence number of the batch that was completed + * @param success Whether the batch was successful or not. + */ + private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { + Option(removeAndGetProcessor(sequenceNumber)).foreach(processor => { + processor.batchProcessed(success) + }) + } + + /** + * Helper method to remove the TxnProcessor for a Sequence Number. Can be used to avoid a leak. + * @param sequenceNumber + * @return The transaction processor for the corresponding batch. Note that this instance is no + * longer tracked and the caller is responsible for that txn processor. + */ + private[sink] def removeAndGetProcessor(sequenceNumber: CharSequence): TransactionProcessor = { + processorMap.remove(sequenceNumber.toString) // The toString is required! 
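+    // (The sequence number that arrives over Avro is a CharSequence implementation such as
+    // org.apache.avro.util.Utf8, while the map is keyed by the String built in getEventBatch,
+    // so it must be normalized with toString for the lookup to succeed.)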
+ } + + /** + * Shuts down the executor used to process transactions. + */ + def shutdown() { + logInfo("Shutting down Spark Avro Callback Handler") + transactionExecutorOpt.foreach(executor => { + executor.shutdownNow() + }) + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala new file mode 100644 index 0000000000000..7b735133e3d14 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.net.InetSocketAddress +import java.util.concurrent._ + +import org.apache.avro.ipc.NettyServer +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.flume.Context +import org.apache.flume.Sink.Status +import org.apache.flume.conf.{Configurable, ConfigurationException} +import org.apache.flume.sink.AbstractSink + +/** + * A sink that uses Avro RPC to run a server that can be polled by Spark's + * FlumePollingInputDStream. This sink has the following configuration parameters: + * + * hostname - The hostname to bind to. Default: 0.0.0.0 + * port - The port to bind to. (No default - mandatory) + * timeout - Time in seconds after which a transaction is rolled back, + * if an ACK is not received from Spark within that time + * threads - Number of threads to use to receive requests from Spark (Default: 10) + * + * This sink is unlike other Flume sinks in the sense that it does not push data, + * instead the process method in this sink simply blocks the SinkRunner the first time it is + * called. This sink starts up an Avro IPC server that uses the SparkFlumeProtocol. + * + * Each time a getEventBatch call comes, creates a transaction and reads events + * from the channel. When enough events are read, the events are sent to the Spark receiver and + * the thread itself is blocked and a reference to it saved off. + * + * When the ack for that batch is received, + * the thread which created the transaction is is retrieved and it commits the transaction with the + * channel from the same thread it was originally created in (since Flume transactions are + * thread local). If a nack is received instead, the sink rolls back the transaction. If no ack + * is received within the specified timeout, the transaction is rolled back too. If an ack comes + * after that, it is simply ignored and the events get re-sent. + * + */ + +private[flume] +class SparkSink extends AbstractSink with Logging with Configurable { + + // Size of the pool to use for holding transaction processors. 
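+  // (Set from the "threads" key in configure(); see SparkSinkConfig.THREADS below.)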
+ private var poolSize: Integer = SparkSinkConfig.DEFAULT_THREADS + + // Timeout for each transaction. If spark does not respond in this much time, + // rollback the transaction + private var transactionTimeout = SparkSinkConfig.DEFAULT_TRANSACTION_TIMEOUT + + // Address info to bind on + private var hostname: String = SparkSinkConfig.DEFAULT_HOSTNAME + private var port: Int = 0 + + private var backOffInterval: Int = 200 + + // Handle to the server + private var serverOpt: Option[NettyServer] = None + + // The handler that handles the callback from Avro + private var handler: Option[SparkAvroCallbackHandler] = None + + // Latch that blocks off the Flume framework from wasting 1 thread. + private val blockingLatch = new CountDownLatch(1) + + override def start() { + logInfo("Starting Spark Sink: " + getName + " on port: " + port + " and interface: " + + hostname + " with " + "pool size: " + poolSize + " and transaction timeout: " + + transactionTimeout + ".") + handler = Option(new SparkAvroCallbackHandler(poolSize, getChannel, transactionTimeout, + backOffInterval)) + val responder = new SpecificResponder(classOf[SparkFlumeProtocol], handler.get) + // Using the constructor that takes specific thread-pools requires bringing in netty + // dependencies which are being excluded in the build. In practice, + // Netty dependencies are already available on the JVM as Flume would have pulled them in. + serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) + serverOpt.foreach(server => { + logInfo("Starting Avro server for sink: " + getName) + server.start() + }) + super.start() + } + + override def stop() { + logInfo("Stopping Spark Sink: " + getName) + handler.foreach(callbackHandler => { + callbackHandler.shutdown() + }) + serverOpt.foreach(server => { + logInfo("Stopping Avro Server for sink: " + getName) + server.close() + server.join() + }) + blockingLatch.countDown() + super.stop() + } + + override def configure(ctx: Context) { + import SparkSinkConfig._ + hostname = ctx.getString(CONF_HOSTNAME, DEFAULT_HOSTNAME) + port = Option(ctx.getInteger(CONF_PORT)). + getOrElse(throw new ConfigurationException("The port to bind to must be specified")) + poolSize = ctx.getInteger(THREADS, DEFAULT_THREADS) + transactionTimeout = ctx.getInteger(CONF_TRANSACTION_TIMEOUT, DEFAULT_TRANSACTION_TIMEOUT) + backOffInterval = ctx.getInteger(CONF_BACKOFF_INTERVAL, DEFAULT_BACKOFF_INTERVAL) + logInfo("Configured Spark Sink with hostname: " + hostname + ", port: " + port + ", " + + "poolSize: " + poolSize + ", transactionTimeout: " + transactionTimeout + ", " + + "backoffInterval: " + backOffInterval) + } + + override def process(): Status = { + // This method is called in a loop by the Flume framework - block it until the sink is + // stopped to save CPU resources. The sink runner will interrupt this thread when the sink is + // being shut down. + logInfo("Blocking Sink Runner, sink will continue to run..") + blockingLatch.await() + Status.BACKOFF + } +} + +/** + * Configuration parameters and their defaults. 
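+ *
+ * Purely as an illustration (it mirrors the test suite added in this patch, and the host/port
+ * values are arbitrary), a sink can be configured and started programmatically with these keys;
+ * a real deployment would set the same keys on the sink in the Flume agent configuration:
+ *
+ * {{{
+ *   import org.apache.flume.Context
+ *   import org.apache.flume.channel.MemoryChannel
+ *   import org.apache.flume.conf.Configurables
+ *
+ *   val context = new Context()
+ *   context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
+ *   context.put(SparkSinkConfig.CONF_PORT, "9999")  // mandatory - there is no default
+ *   context.put(SparkSinkConfig.THREADS, "10")
+ *   context.put(SparkSinkConfig.CONF_TRANSACTION_TIMEOUT, "60")
+ *
+ *   val channel = new MemoryChannel()
+ *   Configurables.configure(channel, context)
+ *
+ *   val sink = new SparkSink()
+ *   Configurables.configure(sink, context)
+ *   sink.setChannel(channel)
+ *   sink.start()
+ * }}}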
+ */ +private[flume] +object SparkSinkConfig { + val THREADS = "threads" + val DEFAULT_THREADS = 10 + + val CONF_TRANSACTION_TIMEOUT = "timeout" + val DEFAULT_TRANSACTION_TIMEOUT = 60 + + val CONF_HOSTNAME = "hostname" + val DEFAULT_HOSTNAME = "0.0.0.0" + + val CONF_PORT = "port" + + val CONF_BACKOFF_INTERVAL = "backoffInterval" + val DEFAULT_BACKOFF_INTERVAL = 200 +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala new file mode 100644 index 0000000000000..47c0e294d6b52 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +private[flume] object SparkSinkUtils { + /** + * This method determines if this batch represents an error or not. + * @param batch - The batch to check + * @return - true if the batch represents an error + */ + def isErrorBatch(batch: EventBatch): Boolean = { + !batch.getErrorMsg.toString.equals("") // If there is an error message, it is an error batch. + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala new file mode 100644 index 0000000000000..b9e3c786ebb3b --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.nio.ByteBuffer +import java.util +import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} + +import scala.util.control.Breaks + +import org.apache.flume.{Transaction, Channel} + +// Flume forces transactions to be thread-local (horrible, I know!) 
+// So the sink basically spawns a new thread to pull the events out within a transaction. +// The thread fills in the event batch object that is set before the thread is scheduled. +// After filling it in, the thread waits on a condition - which is released only +// when the success message comes back for the specific sequence number for that event batch. +/** + * This class represents a transaction on the Flume channel. This class runs a separate thread + * which owns the transaction. The thread is blocked until the success call for that transaction + * comes back with an ACK or NACK. + * @param channel The channel from which to pull events + * @param seqNum The sequence number to use for the transaction. Must be unique + * @param maxBatchSize The maximum number of events to process per batch + * @param transactionTimeout Time in seconds after which a transaction must be rolled back + * without waiting for an ACK from Spark + * @param parent The parent [[SparkAvroCallbackHandler]] instance, for reporting timeouts + */ +private class TransactionProcessor(val channel: Channel, val seqNum: String, + var maxBatchSize: Int, val transactionTimeout: Int, val backOffInterval: Int, + val parent: SparkAvroCallbackHandler) extends Callable[Void] with Logging { + + // If a real batch is not returned, we always have to return an error batch. + @volatile private var eventBatch: EventBatch = new EventBatch("Unknown Error", "", + util.Collections.emptyList()) + + // Synchronization primitives + val batchGeneratedLatch = new CountDownLatch(1) + val batchAckLatch = new CountDownLatch(1) + + // Sanity check to ensure we don't loop like crazy + val totalAttemptsToRemoveFromChannel = Int.MaxValue / 2 + + // OK to use volatile, since the change would only make this true (otherwise it will be + // changed to false - we never apply a negation operation to this) - which means the transaction + // succeeded. + @volatile private var batchSuccess = false + + // The transaction that this processor would handle + var txOpt: Option[Transaction] = None + + /** + * Get an event batch from the channel. This method will block until a batch of events is + * available from the channel. If no events are available after a large number of attempts of + * polling the channel, this method will return an [[EventBatch]] with a non-empty error message + * + * @return An [[EventBatch]] instance with sequence number set to seqNum, filled with a + * maximum of maxBatchSize events + */ + def getEventBatch: EventBatch = { + batchGeneratedLatch.await() + eventBatch + } + + /** + * This method is to be called by the sink when it receives an ACK or NACK from Spark. This + * method is a no-op if it is called after transactionTimeout has expired since + * getEventBatch returned a batch of events. + * @param success True if an ACK was received and the transaction should be committed, else false. + */ + def batchProcessed(success: Boolean) { + logDebug("Batch processed for sequence number: " + seqNum) + batchSuccess = success + batchAckLatch.countDown() + } + + /** + * Populates events into the event batch. If the batch cannot be populated, + * this method will not set the events into the event batch, but it sets an error message. + */ + private def populateEvents() { + try { + txOpt = Option(channel.getTransaction) + if(txOpt.isEmpty) { + eventBatch.setErrorMsg("Something went wrong. 
Channel was " + + "unable to create a transaction!") + } + txOpt.foreach(tx => { + tx.begin() + val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) + val loop = new Breaks + var gotEventsInThisTxn = false + var loopCounter: Int = 0 + loop.breakable { + while (events.size() < maxBatchSize + && loopCounter < totalAttemptsToRemoveFromChannel) { + loopCounter += 1 + Option(channel.take()) match { + case Some(event) => + events.add(new SparkSinkEvent(toCharSequenceMap(event.getHeaders), + ByteBuffer.wrap(event.getBody))) + gotEventsInThisTxn = true + case None => + if (!gotEventsInThisTxn) { + logDebug("Sleeping for " + backOffInterval + " millis as no events were read in" + + " the current transaction") + TimeUnit.MILLISECONDS.sleep(backOffInterval) + } else { + loop.break() + } + } + } + } + if (!gotEventsInThisTxn) { + val msg = "Tried several times, " + + "but did not get any events from the channel!" + logWarning(msg) + eventBatch.setErrorMsg(msg) + } else { + // At this point, the events are available, so fill them into the event batch + eventBatch = new EventBatch("",seqNum, events) + } + }) + } catch { + case e: Exception => + logWarning("Error while processing transaction.", e) + eventBatch.setErrorMsg(e.getMessage) + try { + txOpt.foreach(tx => { + rollbackAndClose(tx, close = true) + }) + } finally { + txOpt = None + } + } finally { + batchGeneratedLatch.countDown() + } + } + + /** + * Waits for upto transactionTimeout seconds for an ACK. If an ACK comes in + * this method commits the transaction with the channel. If the ACK does not come in within + * that time or a NACK comes in, this method rolls back the transaction. + */ + private def processAckOrNack() { + batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) + txOpt.foreach(tx => { + if (batchSuccess) { + try { + logDebug("Committing transaction") + tx.commit() + } catch { + case e: Exception => + logWarning("Error while attempting to commit transaction. Transaction will be rolled " + + "back", e) + rollbackAndClose(tx, close = false) // tx will be closed later anyway + } finally { + tx.close() + } + } else { + logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.") + rollbackAndClose(tx, close = true) + // This might have been due to timeout or a NACK. Either way the following call does not + // cause issues. This is required to ensure the TransactionProcessor instance is not leaked + parent.removeAndGetProcessor(seqNum) + } + }) + } + + /** + * Helper method to rollback and optionally close a transaction + * @param tx The transaction to rollback + * @param close Whether the transaction should be closed or not after rolling back + */ + private def rollbackAndClose(tx: Transaction, close: Boolean) { + try { + logWarning("Spark was unable to successfully process the events. Transaction is being " + + "rolled back.") + tx.rollback() + } catch { + case e: Exception => + logError("Error rolling back transaction. 
Rollback may have failed!", e) + } finally { + if (close) { + tx.close() + } + } + } + + /** + * Helper method to convert a Map[String, String] to Map[CharSequence, CharSequence] + * @param inMap The map to be converted + * @return The converted map + */ + private def toCharSequenceMap(inMap: java.util.Map[String, String]): java.util.Map[CharSequence, + CharSequence] = { + val charSeqMap = new util.HashMap[CharSequence, CharSequence](inMap.size()) + charSeqMap.putAll(inMap) + charSeqMap + } + + /** + * When the thread is started it sets as many events as the batch size or less (if enough + * events aren't available) into the eventBatch and object and lets any threads waiting on the + * [[getEventBatch]] method to proceed. Then this thread waits for acks or nacks to come in, + * or for a specified timeout and commits or rolls back the transaction. + * @return + */ + override def call(): Void = { + populateEvents() + processAckOrNack() + null + } +} diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 61a6aff543aed..c532705f3950c 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-flume_2.10 - streaming-flume + streaming-flume jar Spark Project External Flume @@ -72,11 +72,21 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface test + + org.apache.spark + spark-streaming-flume-sink_2.10 + ${project.version} + target/scala-${scala.binary.version}/classes diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala new file mode 100644 index 0000000000000..dc629df4f4ac2 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.io.{ObjectOutput, ObjectInput} + +import scala.collection.JavaConversions._ + +import org.apache.spark.util.Utils +import org.apache.spark.Logging + +/** + * A simple object that provides the implementation of readExternal and writeExternal for both + * the wrapper classes for Flume-style Events. 
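+ *
+ * The byte layout written by writeExternal (and expected by readExternal) is: the body length
+ * as an Int, the body bytes, the number of headers as an Int, and then for each header the
+ * serialized key length and bytes followed by the serialized value length and bytes.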
+ */ +private[streaming] object EventTransformer extends Logging { + def readExternal(in: ObjectInput): (java.util.HashMap[CharSequence, CharSequence], + Array[Byte]) = { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.readFully(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.readFully(keyBuff) + val key: String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.readFully(valBuff) + val value: String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + (headers, bodyBuff) + } + + def writeExternal(out: ObjectOutput, headers: java.util.Map[CharSequence, CharSequence], + body: Array[Byte]) { + out.writeInt(body.length) + out.write(body) + val numHeaders = headers.size() + out.writeInt(numHeaders) + for ((k,v) <- headers) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 56d2886b26878..4b2ea45fb81d0 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -39,11 +39,8 @@ import org.apache.spark.streaming.receiver.Receiver import org.jboss.netty.channel.ChannelPipelineFactory import org.jboss.netty.channel.Channels -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.ChannelFactory import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.jboss.netty.handler.execution.ExecutionHandler private[streaming] class FlumeInputDStream[T: ClassTag]( diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala new file mode 100644 index 0000000000000..148262bb6771e --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.flume + + +import java.net.InetSocketAddress +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit, Executors} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.flume.sink._ + +/** + * A [[ReceiverInputDStream]] that can be used to read data from several Flume agents running + * [[org.apache.spark.streaming.flume.sink.SparkSink]]s. + * @param _ssc Streaming context that will execute this input stream + * @param addresses List of addresses at which SparkSinks are listening + * @param maxBatchSize Maximum size of a batch + * @param parallelism Number of parallel connections to open + * @param storageLevel The storage level to use. + * @tparam T Class type of the object of this stream + */ +private[streaming] class FlumePollingInputDStream[T: ClassTag]( + @transient _ssc: StreamingContext, + val addresses: Seq[InetSocketAddress], + val maxBatchSize: Int, + val parallelism: Int, + storageLevel: StorageLevel + ) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { + + override def getReceiver(): Receiver[SparkFlumeEvent] = { + new FlumePollingReceiver(addresses, maxBatchSize, parallelism, storageLevel) + } +} + +private[streaming] class FlumePollingReceiver( + addresses: Seq[InetSocketAddress], + maxBatchSize: Int, + parallelism: Int, + storageLevel: StorageLevel + ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { + + lazy val channelFactoryExecutor = + Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). + setNameFormat("Flume Receiver Channel Thread - %d").build()) + + lazy val channelFactory = + new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) + + lazy val receiverExecutor = Executors.newFixedThreadPool(parallelism, + new ThreadFactoryBuilder().setDaemon(true).setNameFormat("Flume Receiver Thread - %d").build()) + + private lazy val connections = new LinkedBlockingQueue[FlumeConnection]() + + override def onStart(): Unit = { + // Create the connections to each Flume agent. + addresses.foreach(host => { + val transceiver = new NettyTransceiver(host, channelFactory) + val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) + connections.add(new FlumeConnection(transceiver, client)) + }) + for (i <- 0 until parallelism) { + logInfo("Starting Flume Polling Receiver worker threads starting..") + // Threads that pull data from Flume. 
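+      // Each worker loops forever: it borrows a connection from the shared queue, asks the
+      // sink for a batch, stores the events and acks on success (or nacks on failure), and
+      // finally returns the connection to the queue for reuse.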
+ receiverExecutor.submit(new Runnable { + override def run(): Unit = { + while (true) { + val connection = connections.poll() + val client = connection.client + try { + val eventBatch = client.getEventBatch(maxBatchSize) + if (!SparkSinkUtils.isErrorBatch(eventBatch)) { + // No error, proceed with processing data + val seq = eventBatch.getSequenceNumber + val events: java.util.List[SparkSinkEvent] = eventBatch.getEvents + logDebug( + "Received batch of " + events.size() + " events with sequence number: " + seq) + try { + // Convert each Flume event to a serializable SparkFlumeEvent + val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) + var j = 0 + while (j < events.size()) { + buffer += toSparkFlumeEvent(events(j)) + j += 1 + } + store(buffer) + logDebug("Sending ack for sequence number: " + seq) + // Send an ack to Flume so that Flume discards the events from its channels. + client.ack(seq) + logDebug("Ack sent for sequence number: " + seq) + } catch { + case e: Exception => + try { + // Let Flume know that the events need to be pushed back into the channel. + logDebug("Sending nack for sequence number: " + seq) + client.nack(seq) // If the agent is down, even this could fail and throw + logDebug("Nack sent for sequence number: " + seq) + } catch { + case e: Exception => logError( + "Sending Nack also failed. A Flume agent is down.") + } + TimeUnit.SECONDS.sleep(2L) // for now just leave this as a fixed 2 seconds. + logWarning("Error while attempting to store events", e) + } + } else { + logWarning("Did not receive events from Flume agent due to error on the Flume " + + "agent: " + eventBatch.getErrorMsg) + } + } catch { + case e: Exception => + logWarning("Error while reading data from Flume", e) + } finally { + connections.add(connection) + } + } + } + }) + } + } + + override def onStop(): Unit = { + logInfo("Shutting down Flume Polling Receiver") + receiverExecutor.shutdownNow() + connections.foreach(connection => { + connection.transceiver.close() + }) + channelFactory.releaseExternalResources() + } + + /** + * Utility method to convert [[SparkSinkEvent]] to [[SparkFlumeEvent]] + * @param event - Event to convert to SparkFlumeEvent + * @return - The SparkFlumeEvent generated from SparkSinkEvent + */ + private def toSparkFlumeEvent(event: SparkSinkEvent): SparkFlumeEvent = { + val sparkFlumeEvent = new SparkFlumeEvent() + sparkFlumeEvent.event.setBody(event.getBody) + sparkFlumeEvent.event.setHeaders(event.getHeaders) + sparkFlumeEvent + } +} + +/** + * A wrapper around the transceiver and the Avro IPC API. + * @param transceiver The transceiver to use for communication with Flume + * @param client The client that the callbacks are received on. 
+ */ +private class FlumeConnection(val transceiver: NettyTransceiver, + val client: SparkFlumeProtocol.Callback) + + + diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 716db9fa76031..4b732c1592ab2 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -17,12 +17,19 @@ package org.apache.spark.streaming.flume +import java.net.InetSocketAddress + +import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaInputDStream, JavaStreamingContext, JavaDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream + object FlumeUtils { + private val DEFAULT_POLLING_PARALLELISM = 5 + private val DEFAULT_POLLING_BATCH_SIZE = 1000 + /** * Create a input stream from a Flume source. * @param ssc StreamingContext object @@ -56,7 +63,7 @@ object FlumeUtils { ): ReceiverInputDStream[SparkFlumeEvent] = { val inputStream = new FlumeInputDStream[SparkFlumeEvent]( ssc, hostname, port, storageLevel, enableDecompression) - + inputStream } @@ -105,4 +112,135 @@ object FlumeUtils { ): JavaReceiverInputDStream[SparkFlumeEvent] = { createStream(jssc.ssc, hostname, port, storageLevel, enableDecompression) } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Address of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): ReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(ssc, Seq(new InetSocketAddress(hostname, port)), storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param addresses List of InetSocketAddresses representing the hosts to connect to. + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + addresses: Seq[InetSocketAddress], + storageLevel: StorageLevel + ): ReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(ssc, addresses, storageLevel, + DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * @param addresses List of InetSocketAddresses representing the hosts to connect to. 
+ * @param maxBatchSize Maximum number of events to be pulled from the Spark sink in a + * single RPC call + * @param parallelism Number of concurrent requests this stream should send to the sink. Note + * that having a higher number of requests concurrently being pulled will + * result in this stream using more threads + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + addresses: Seq[InetSocketAddress], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): ReceiverInputDStream[SparkFlumeEvent] = { + new FlumePollingInputDStream[SparkFlumeEvent](ssc, addresses, maxBatchSize, + parallelism, storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Hostname of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, hostname, port, StorageLevel.MEMORY_AND_DISK_SER_2) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Hostname of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, Array(new InetSocketAddress(hostname, port)), storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param addresses List of InetSocketAddresses on which the Spark Sink is running. + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + addresses: Array[InetSocketAddress], + storageLevel: StorageLevel + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, addresses, storageLevel, + DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * @param addresses List of InetSocketAddresses on which the Spark Sink is running + * @param maxBatchSize The maximum number of events to be pulled from the Spark sink in a + * single RPC call + * @param parallelism Number of concurrent requests this stream should send to the sink. 
Note + * that having a higher number of requests concurrently being pulled will + * result in this stream using more threads + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + addresses: Array[InetSocketAddress], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + } } diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java new file mode 100644 index 0000000000000..79c5b91654b42 --- /dev/null +++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume; + +import java.net.InetSocketAddress; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.LocalJavaStreamingContext; + +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.junit.Test; + +public class JavaFlumePollingStreamSuite extends LocalJavaStreamingContext { + @Test + public void testFlumeStream() { + // tests the API, does not actually test data receiving + InetSocketAddress[] addresses = new InetSocketAddress[] { + new InetSocketAddress("localhost", 12345) + }; + JavaReceiverInputDStream test1 = + FlumeUtils.createPollingStream(ssc, "localhost", 12345); + JavaReceiverInputDStream test2 = FlumeUtils.createPollingStream( + ssc, "localhost", 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = FlumeUtils.createPollingStream( + ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test4 = FlumeUtils.createPollingStream( + ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2(), 100, 5); + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala new file mode 100644 index 0000000000000..27bf2ac962721 --- /dev/null +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.streaming.flume + +import java.net.InetSocketAddress +import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} +import java.util.Random + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} + +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables +import org.apache.flume.event.EventBuilder + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} +import org.apache.spark.streaming.flume.sink._ + +class FlumePollingStreamSuite extends TestSuiteBase { + + val random = new Random() + /** Return a port in the ephemeral range. */ + def getTestPort = random.nextInt(16382) + 49152 + val batchCount = 5 + val eventsPerBatch = 100 + val totalEventsPerChannel = batchCount * eventsPerBatch + val channelCapacity = 5000 + + test("flume polling test") { + val testPort = getTestPort + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", testPort)), + StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + + // Start the channel and sink. 
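+    // The MemoryChannel and SparkSink are wired together by hand here, the same way a Flume
+    // agent would wire them from its configuration.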
+ val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + ssc.start() + + writeAndVerify(Seq(channel), ssc, outputBuffer) + assertChannelIsEmpty(channel) + sink.stop() + channel.stop() + } + + test("flume polling test multiple hosts") { + val testPort = getTestPort + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, + eventsPerBatch, 5) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort + 1)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + ssc.start() + writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + sink.stop() + channel.stop() + } + + def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, + outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) { + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + channels.map(channel => { + executorCompletion.submit(new TxnSubmitter(channel, clock)) + }) + for (i <- 0 until channels.size) { + executorCompletion.take() + } + val startTime = System.currentTimeMillis() + while (outputBuffer.size < batchCount * channels.size && + System.currentTimeMillis() - startTime < 15000) { + logInfo("output.size = " + outputBuffer.size) + Thread.sleep(100) + } + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping context") + ssc.stop() + + val flattenedBuffer = outputBuffer.flatten + assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + + String.valueOf(i)).getBytes("utf-8"), + Map[String, String]("test-" + i.toString -> "header")) + var found = false + var 
j = 0 + while (j < flattenedBuffer.size && !found) { + val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") + if (new String(eventToVerify.getBody, "utf-8") == strToCompare && + eventToVerify.getHeaders.get("test-" + i.toString) + .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { + found = true + counter += 1 + } + j += 1 + } + } + assert(counter === totalEventsPerChannel * channels.size) + } + + def assertChannelIsEmpty(channel: MemoryChannel) = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining"); + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) + } + + private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( + "utf-8"), + Map[String, String]("test-" + t.toString -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + clock.addToTime(batchDuration.milliseconds) + } + null + } + } +} diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4762c50685a93..4e2275ab238f7 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-kafka_2.10 - streaming-kafka + streaming-kafka jar Spark Project External Kafka @@ -68,8 +68,18 @@ org.slf4j slf4j-simple + + org.apache.zookeeper + zookeeper + + + net.sf.jopt-simple + jopt-simple + 3.2 + test + org.scalatest scalatest_${scala.binary.version} @@ -80,6 +90,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 38095e88dcea9..e20e2c8f26991 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import scala.collection.Map -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import java.util.Properties import java.util.concurrent.Executors @@ -48,8 +48,8 @@ private[streaming] class KafkaInputDStream[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -66,8 +66,8 @@ private[streaming] class KafkaReceiver[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel @@ -103,10 +103,10 @@ class KafkaReceiver[ tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) } - val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) 
.asInstanceOf[Decoder[K]] - val valueDecoder = manifest[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 86bb91f362d29..48668f763e41e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -65,7 +65,7 @@ object KafkaUtils { * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ - def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: Manifest, T <: Decoder[_]: Manifest]( + def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -89,8 +89,6 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) } @@ -111,8 +109,6 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -140,13 +136,11 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[K, V] = { - implicit val keyCmt: ClassTag[K] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] - implicit val valueCmt: ClassTag[V] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + implicit val keyCmt: ClassTag[K] = ClassTag(keyTypeClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueTypeClass) - implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]] - implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]] + implicit val keyCmd: ClassTag[U] = ClassTag(keyDecoderClass) + implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 9f8046bf00f8f..0571454c01dae 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -17,31 +17,118 @@ package org.apache.spark.streaming.kafka; +import java.io.Serializable; import java.util.HashMap; +import java.util.List; + +import scala.Predef; +import scala.Tuple2; +import scala.collection.JavaConverters; + +import junit.framework.Assert; -import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream; -import org.junit.Test; -import com.google.common.collect.Maps; import kafka.serializer.StringDecoder; + +import 
org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { + private transient KafkaStreamSuite testSuite = new KafkaStreamSuite(); + + @Before + @Override + public void setUp() { + testSuite.beforeFunction(); + System.clearProperty("spark.driver.port"); + //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + } + + @After + @Override + public void tearDown() { + ssc.stop(); + ssc = null; + System.clearProperty("spark.driver.port"); + testSuite.afterFunction(); + } -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext { @Test - public void testKafkaStream() { - HashMap topics = Maps.newHashMap(); - - // tests the API, does not actually test data receiving - JavaPairReceiverInputDStream test1 = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics); - JavaPairReceiverInputDStream test2 = KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, - StorageLevel.MEMORY_AND_DISK_SER_2()); - - HashMap kafkaParams = Maps.newHashMap(); - kafkaParams.put("zookeeper.connect", "localhost:12345"); - kafkaParams.put("group.id","consumer-group"); - JavaPairReceiverInputDStream test3 = KafkaUtils.createStream(ssc, - String.class, String.class, StringDecoder.class, StringDecoder.class, - kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2()); + public void testKafkaStream() throws InterruptedException { + String topic = "topic1"; + HashMap topics = new HashMap(); + topics.put(topic, 1); + + HashMap sent = new HashMap(); + sent.put("a", 5); + sent.put("b", 3); + sent.put("c", 10); + + testSuite.createTopic(topic); + HashMap tmp = new HashMap(sent); + testSuite.produceAndSendMessage(topic, + JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("zookeeper.connect", testSuite.zkConnect()); + kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("auto.offset.reset", "smallest"); + + JavaPairDStream stream = KafkaUtils.createStream(ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topics, + StorageLevel.MEMORY_ONLY_SER()); + + final HashMap result = new HashMap(); + + JavaDStream words = stream.map( + new Function, String>() { + @Override + public String call(Tuple2 tuple2) throws Exception { + return tuple2._2(); + } + } + ); + + words.countByValue().foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws Exception { + List> ret = rdd.collect(); + for (Tuple2 r : ret) { + if (result.containsKey(r._1())) { + result.put(r._1(), result.get(r._1()) + r._2()); + } else { + result.put(r._1(), r._2()); + } + } + + return null; + } + } + ); + + ssc.start(); + ssc.awaitTermination(3000); + + Assert.assertEquals(sent.size(), result.size()); + for (String k : sent.keySet()) { + 
Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index e6f2c4a5cf5d1..c0b55e9340253 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,28 +17,193 @@ package org.apache.spark.streaming.kafka -import kafka.serializer.StringDecoder +import java.io.File +import java.net.InetSocketAddress +import java.util.{Properties, Random} + +import scala.collection.mutable + +import kafka.admin.CreateTopicCommand +import kafka.common.TopicAndPartition +import kafka.producer.{KeyedMessage, ProducerConfig, Producer} +import kafka.utils.ZKStringSerializer +import kafka.serializer.{StringDecoder, StringEncoder} +import kafka.server.{KafkaConfig, KafkaServer} + +import org.I0Itec.zkclient.ZkClient + +import org.apache.zookeeper.server.ZooKeeperServer +import org.apache.zookeeper.server.NIOServerCnxnFactory + import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.util.Utils class KafkaStreamSuite extends TestSuiteBase { + import KafkaTestUtils._ + + val zkConnect = "localhost:2181" + val zkConnectionTimeout = 6000 + val zkSessionTimeout = 6000 + + val brokerPort = 9092 + val brokerProps = getBrokerConfig(brokerPort, zkConnect) + val brokerConf = new KafkaConfig(brokerProps) + + protected var zookeeper: EmbeddedZookeeper = _ + protected var zkClient: ZkClient = _ + protected var server: KafkaServer = _ + protected var producer: Producer[String, String] = _ + + override def useManualClock = false + + override def beforeFunction() { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(zkConnect) + logInfo("==================== 0 ====================") + zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + logInfo("==================== 1 ====================") - test("kafka input stream") { + // Kafka broker startup + server = new KafkaServer(brokerConf) + logInfo("==================== 2 ====================") + server.startup() + logInfo("==================== 3 ====================") + Thread.sleep(2000) + logInfo("==================== 4 ====================") + super.beforeFunction() + } + + override def afterFunction() { + producer.close() + server.shutdown() + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + zkClient.close() + zookeeper.shutdown() + + super.afterFunction() + } + + test("Kafka input stream") { val ssc = new StreamingContext(master, framework, batchDuration) - val topics = Map("my-topic" -> 1) - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:1234", "group", topics) - val test2: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK_SER_2) - val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group") - val test3: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2) - - // 
TODO: Actually test receiving data + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkConnect, + "group.id" -> s"test-consumer-${random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, + kafkaParams, + Map(topic -> 1), + StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v } + .countByValue() + .foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + } + ssc.start() + ssc.awaitTermination(3000) + + assert(sent.size === result.size) + sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + ssc.stop() } + + private def createTestMessage(topic: String, sent: Map[String, Int]) + : Seq[KeyedMessage[String, String]] = { + val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { + new KeyedMessage[String, String](topic, s) + } + messages.toSeq + } + + def createTopic(topic: String) { + CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") + logInfo("==================== 5 ====================") + // wait until metadata is propagated + waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + } + + def produceAndSendMessage(topic: String, sent: Map[String, Int]) { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port + producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer.send(createTestMessage(topic, sent): _*) + logInfo("==================== 6 ====================") + } +} + +object KafkaTestUtils { + val random = new Random() + + def getBrokerConfig(port: Int, zkConnect: String): Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", port.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkConnect) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + def getProducerConfig(brokerList: String): Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerList) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { + val startTime = System.currentTimeMillis() + while (true) { + if (condition()) + return true + if (System.currentTimeMillis() > startTime + waitTime) + return false + Thread.sleep(waitTime.min(100L)) + } + // Should never go to here + throw new RuntimeException("unexpected error") + } + + def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, + timeout: Long) { + assert(waitUntilTrue(() => + servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( + TopicAndPartition(topic, partition))), timeout), + s"Partition [$topic, $partition] metadata not propagated after timeout") + } + + class EmbeddedZookeeper(val zkConnect: String) { + val random = new Random() + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new 
InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } } diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 32c530e600ce0..dc48a08c93de2 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-mqtt_2.10 - streaming-mqtt + streaming-mqtt jar Spark Project External MQTT @@ -67,6 +67,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 637adb0f00da0..b93ad016f84f0 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-twitter_2.10 - streaming-twitter + streaming-twitter jar Spark Project External Twitter @@ -62,6 +62,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e4d758a04a4cd..22c1fff23d9a2 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-zeromq_2.10 - streaming-zeromq + streaming-zeromq jar Spark Project External ZeroMQ @@ -62,6 +62,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 3eade411b38b7..5308bb4e440ea 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -50,6 +50,11 @@ ${project.version} test-jar + + junit + junit + test + com.novocode junit-interface diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml new file mode 100644 index 0000000000000..a54b34235dfb4 --- /dev/null +++ b/extras/kinesis-asl/pom.xml @@ -0,0 +1,96 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + + org.apache.spark + spark-streaming-kinesis-asl_2.10 + jar + Spark Kinesis Integration + + + kinesis-asl + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + + + com.amazonaws + amazon-kinesis-client + ${aws.kinesis.client.version} + + + com.amazonaws + aws-java-sdk + ${aws.java.sdk.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.mockito + mockito-all + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.easymock + easymockclassextension + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java new file mode 100644 index 0000000000000..a8b907b241893 --- /dev/null +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.streaming; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import org.apache.log4j.Logger; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.kinesis.KinesisUtils; + +import scala.Tuple2; + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.kinesis.AmazonKinesisClient; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import com.google.common.collect.Lists; + +/** + * Java-friendly Kinesis Spark Streaming WordCount example + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details + * on the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: JavaKinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data + * onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. 
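+ *
+ * At its core, each of the receivers created in main() below comes from a call of roughly
+ * this shape (a sketch only; jssc, streamName, endpointUrl, and checkpointInterval are the
+ * variables defined in main()):
+ * <pre>
+ * {@code
+ * JavaDStream<byte[]> stream = KinesisUtils.createStream(jssc, streamName, endpointUrl,
+ *     checkpointInterval, InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2());
+ * }
+ * </pre>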
+ */ +public final class JavaKinesisWordCountASL { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + /* Make the constructor private to enforce singleton */ + private JavaKinesisWordCountASL() { + } + + public static void main(String[] args) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + "|Usage: KinesisWordCount \n" + + "| is the name of the Kinesis stream\n" + + "| is the endpoint of the Kinesis service\n" + + "| (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + /* Populate the appropriate variables from the given args */ + String streamName = args[0]; + String endpointUrl = args[1]; + /* Set the batch interval to a fixed 2000 millis (2 seconds) */ + Duration batchInterval = new Duration(2000); + + /* Create a Kinesis client in order to determine the number of shards for the given stream */ + AmazonKinesisClient kinesisClient = new AmazonKinesisClient( + new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + + /* Determine the number of shards from the stream */ + int numShards = kinesisClient.describeStream(streamName) + .getStreamDescription().getShards().size(); + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ + int numStreams = numShards; + + /* Must add 1 more thread than the number of receivers or the output won't show properly from the driver */ + int numSparkThreads = numStreams + 1; + + /* Setup the Spark config. */ + SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount").setMaster( + "local[" + numSparkThreads + "]"); + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + Duration checkpointInterval = batchInterval; + + /* Setup the StreamingContext */ + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) + ); + } + + /* Union all the streams if there is more than 1 stream */ + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + /* Otherwise, just use the 1 stream */ + unionStreams = streamsList.get(0); + } + + /* + * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. + * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. + */ + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. 
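+     * For example, a record containing the text "a b a" contributes ("a", 2) and ("b", 1)
+     * to that batch's counts.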
*/ + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + /* Print the first 10 wordCounts */ + wordCounts.print(); + + /* Start the streaming context and await termination */ + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties new file mode 100644 index 0000000000000..97348fb5b6123 --- /dev/null +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +log4j.rootCategory=WARN, console + +# File appender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Console appender +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.out +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala new file mode 100644 index 0000000000000..d03edf8b30a9f --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import java.nio.ByteBuffer +import scala.util.Random +import org.apache.spark.Logging +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.PutRecordRequest +import org.apache.log4j.Logger +import org.apache.log4j.Level + +/** + * Kinesis Spark Streaming WordCount example. + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on + * the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: KinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class below called KinesisWordCountProducerASL which puts + * dummy data onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + */ +object KinesisWordCountASL extends Logging { + def main(args: Array[String]) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + """ + |Usage: KinesisWordCount + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. https://kinesis.us-east-1.amazonaws.com) + """.stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(streamName, endpointUrl) = args + + /* Determine the number of shards from the stream */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpointUrl) + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() + .size() + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + val numStreams = numShards + + /* + * numSparkThreads should be 1 more thread than the number of receivers. 
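+     * For example, a stream with 2 shards yields 2 receivers, so the master below is local[3].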
+ * This leaves one thread available for actually processing the data. + */ + val numSparkThreads = numStreams + 1 + + /* Setup the and SparkConfig and StreamingContext */ + /* Spark Streaming batch interval */ + val batchInterval = Milliseconds(2000) + val sparkConfig = new SparkConf().setAppName("KinesisWordCount") + .setMaster(s"local[$numSparkThreads]") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + val kinesisCheckpointInterval = batchInterval + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + val kinesisStreams = (0 until numStreams).map { i => + KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + } + + /* Union all the streams */ + val unionStreams = ssc.union(kinesisStreams) + + /* Convert each line of Array[Byte] to String, split into words, and count them */ + val words = unionStreams.flatMap(byteArray => new String(byteArray) + .split(" ")) + + /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) + + /* Print the first 10 wordCounts */ + wordCounts.print() + + /* Start the streaming context and await termination */ + ssc.start() + ssc.awaitTermination() + } +} + +/** + * Usage: KinesisWordCountProducerASL + * + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * is the rate of records per second to put onto the stream + * is the rate of records per second to put onto the stream + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com 10 5 + */ +object KinesisWordCountProducerASL { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: KinesisWordCountProducerASL " + + " ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args + + /* Generate the records and return the totals */ + val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + + /* Print the array of (index, total) tuples */ + println("Totals") + totals.foreach(total => println(total.toString())) + } + + def generate(stream: String, + endpoint: String, + recordsPerSecond: Int, + wordsPerRecord: Int): Seq[(Int, Int)] = { + + val MaxRandomInts = 10 + + /* Create the Kinesis client */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpoint) + + println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + + s" $recordsPerSecond records per second and $wordsPerRecord words per record"); + + val totals = new Array[Int](MaxRandomInts) + /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ + for (i <- 1 to 5) { + + /* Generate recordsPerSec records to put onto the stream */ + val records = (1 to recordsPerSecond.toInt).map { recordNum => + /* + * Randomly generate each wordsPerRec words between 0 (inclusive) + * and MAX_RANDOM_INTS (exclusive) + */ + val data = (1 to wordsPerRecord.toInt).map(x 
=> { + /* Generate the random int */ + val randomInt = Random.nextInt(MaxRandomInts) + + /* Keep track of the totals */ + totals(randomInt) += 1 + + randomInt.toString() + }).mkString(" ") + + /* Create a partitionKey based on recordNum */ + val partitionKey = s"partitionKey-$recordNum" + + /* Create a PutRecordRequest with an Array[Byte] version of the data */ + val putRecordRequest = new PutRecordRequest().withStreamName(stream) + .withPartitionKey(partitionKey) + .withData(ByteBuffer.wrap(data.getBytes())); + + /* Put the record onto the stream and capture the PutRecordResult */ + val putRecordResult = kinesisClient.putRecord(putRecordRequest); + } + + /* Sleep for a second */ + Thread.sleep(1000) + println("Sent " + recordsPerSecond + " records") + } + + /* Convert the totals to (index, total) tuple */ + (0 to (MaxRandomInts - 1)).zip(totals) + } +} + +/** + * Utility functions for Spark Streaming examples. + * This has been lifted from the examples/ project to remove the circular dependency. + */ +object StreamingExamples extends Logging { + + /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + def setStreamingLogLevels() { + val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements + if (!log4jInitialized) { + // We first log something to initialize Spark's default logging, then we override the + // logging level. + logInfo("Setting log level to [WARN] for streaming example." + + " To override add a custom log4j.properties to the classpath.") + Logger.getRootLogger.setLevel(Level.WARN) + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala new file mode 100644 index 0000000000000..0b80b611cdce7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.util.SystemClock + +/** + * This is a helper class for managing checkpoint clocks. + * + * @param checkpointInterval + * @param currentClock. 
Default to current SystemClock if none is passed in (mocking purposes) + */ +private[kinesis] class KinesisCheckpointState( + checkpointInterval: Duration, + currentClock: Clock = new SystemClock()) + extends Logging { + + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ + val checkpointClock = new ManualClock() + checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + + /** + * Check if it's time to checkpoint based on the current time and the derived time + * for the next checkpoint + * + * @return true if it's time to checkpoint + */ + def shouldCheckpoint(): Boolean = { + new SystemClock().currentTime() > checkpointClock.currentTime() + } + + /** + * Advance the checkpoint clock by the checkpoint interval. + */ + def advanceCheckpoint() = { + checkpointClock.addToTime(checkpointInterval.milliseconds) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala new file mode 100644 index 0000000000000..1bd1f324298e7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.net.InetAddress +import java.util.UUID + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.receiver.Receiver + +import com.amazonaws.auth.AWSCredentialsProvider +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +/** + * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. + * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: + * https://github.com/awslabs/amazon-kinesis-client + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) + * as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers + * to run within a Spark Executor. + * + * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams + * by the Kinesis Client Library. 
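+ *                (The KinesisUtils helpers pass the Spark application name, ssc.sc.appName,
+ *                as this value.)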
If you change the App name or Stream name, + * the KCL will throw errors. This usually requires deleting the backing + * DynamoDB table with the same name this Kinesis application. + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ +private[kinesis] class KinesisReceiver( + appName: String, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel) + extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => + + /* + * The following vars are built in the onStart() method which executes in the Spark Worker after + * this code is serialized and shipped remotely. + */ + + /* + * workerId should be based on the ip address of the actual Spark Worker where this code runs + * (not the Driver's ip address.) + */ + var workerId: String = null + + /* + * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file at the default location (~/.aws/credentials) shared by all + * AWS SDKs and the AWS CLI + * Instance profile credentials delivered through the Amazon EC2 metadata service + */ + var credentialsProvider: AWSCredentialsProvider = null + + /* KCL config instance. */ + var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + var recordProcessorFactory: IRecordProcessorFactory = null + + /* + * Create a Kinesis Worker. + * This is the core client abstraction from the Kinesis Client Library (KCL). + * We pass the RecordProcessorFactory from above as well as the KCL config instance. + * A Kinesis Worker can process 1..* shards from the given stream - each with its + * own RecordProcessor. + */ + var worker: Worker = null + + /** + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through the Worker.run() + * method. 
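+   * Those threads run the KinesisRecordProcessor instances built by recordProcessorFactory,
+   * which hand each record batch to this receiver via store().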
+ */ + override def onStart() { + workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + credentialsProvider = new DefaultAWSCredentialsProviderChain() + kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, + credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) + recordProcessorFactory = new IRecordProcessorFactory { + override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, + workerId, new KinesisCheckpointState(checkpointInterval)) + } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) + worker.run() + logInfo(s"Started receiver with workerId $workerId") + } + + /** + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + */ + override def onStop() { + worker.shutdown() + logInfo(s"Shut down receiver with workerId $workerId") + workerId = null + credentialsProvider = null + kinesisClientLibConfiguration = null + recordProcessorFactory = null + worker = null + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala new file mode 100644 index 0000000000000..8ecc2d90160b1 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.List + +import scala.collection.JavaConversions.asScalaBuffer +import scala.util.Random + +import org.apache.spark.Logging + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. + * This implementation operates on the Array[Byte] from the KinesisReceiver. 
+ * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * + * @param receiver Kinesis receiver + * @param workerId for logging purposes + * @param checkpointState represents the checkpoint state including the next checkpoint time. + * It's injected here for mocking purposes. + */ +private[kinesis] class KinesisRecordProcessor( + receiver: KinesisReceiver, + workerId: String, + checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { + + /* shardId to be populated during initialize() */ + var shardId: String = _ + + /** + * The Kinesis Client Library calls this method during IRecordProcessor initialization. + * + * @param shardId assigned by the KCL to this particular RecordProcessor. + */ + override def initialize(shardId: String) { + logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") + this.shardId = shardId + } + + /** + * This method is called by the KCL when a batch of records is pulled from the Kinesis stream. + * This is the record-processing bridge between the KCL's IRecordProcessor.processRecords() + * and Spark Streaming's Receiver.store(). + * + * @param batch list of records from the Kinesis stream shard + * @param checkpointer used to update Kinesis when this batch has been processed/stored + * in the DStream + */ + override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { + if (!receiver.isStopped()) { + try { + /* + * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). + * This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + */ + batch.foreach(record => receiver.store(record.getData().array())) + + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + + /* + * Checkpoint the sequence number of the last record successfully processed/stored + * in the batch. + * In this implementation, we're checkpointing after the given checkpointIntervalMillis. + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the + * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. + * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). + * However, if the worker dies unexpectedly, a checkpoint may not happen. + * This could lead to records being processed more than once. + */ + if (checkpointState.shouldCheckpoint()) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* Update the next checkpoint time */ + checkpointState.advanceCheckpoint() + + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + + s" records for shardId $shardId") + logDebug(s"Checkpoint: Next checkpoint is at " + + s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + } + } catch { + case e: Throwable => { + /* + * If there is a failure within the batch, the batch will not be checkpointed. + * This will potentially cause records since the last checkpoint to be processed + * more than once. 
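+         * In other words, delivery from this receiver is at-least-once rather than
+         * exactly-once.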
+ */ + logError(s"Exception: WorkerId $workerId encountered and exception while storing " + + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + throw e + } + } + } else { + /* RecordProcessor has been stopped. */ + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + s" and shardId $shardId. No more records will be processed.") + } + } + + /** + * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: + * 1) the stream is resharding by splitting or merging adjacent shards + * (ShutdownReason.TERMINATE) + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * (ShutdownReason.ZOMBIE) + * + * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE + * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) + */ + override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) { + logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason") + reason match { + /* + * TERMINATE Use Case. Checkpoint. + * Checkpoint to indicate that all records from the shard have been drained and processed. + * It's now OK to read from the new shards that resulted from a resharding event. + */ + case ShutdownReason.TERMINATE => + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* + * ZOMBIE Use Case. NoOp. + * No checkpoint because other workers may have taken over and already started processing + * the same records. + * This may lead to records being processed more than once. + */ + case ShutdownReason.ZOMBIE => + + /* Unknown reason. NoOp */ + case _ => + } + } +} + +private[kinesis] object KinesisRecordProcessor extends Logging { + /** + * Retry the given amount of times with a random backoff time (millis) less than the + * given maxBackOffMillis + * + * @param expression expression to evalute + * @param numRetriesLeft number of retries left + * @param maxBackOffMillis: max millis between retries + * + * @return evaluation of the given expression + * @throws Unretryable exception, unexpected exception, + * or any exception that persists after numRetriesLeft reaches 0 + */ + @annotation.tailrec + def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = { + util.Try { expression } match { + /* If the function succeeded, evaluate to x. */ + case util.Success(x) => x + /* If the function failed, either retry or throw the exception */ + case util.Failure(e) => e match { + /* Retry: Throttling or other Retryable exception has occurred */ + case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 + => { + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) + } + /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + case _: ShutdownException => { + logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) + throw e + } + /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ + case _: InvalidStateException => { + logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + + s" by the Amazon Kinesis Client Library. 
Table likely doesn't exist.", e) + throw e + } + /* Throw: Unexpected exception has occurred */ + case _ => { + logError(s"Unexpected, non-retryable exception.", e) + throw e + } + } + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala new file mode 100644 index 0000000000000..713cac0e293c0 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream +import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.dstream.ReceiverInputDStream + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + +/** + * Helper class to create Amazon Kinesis Input Stream + * :: Experimental :: + */ +@Experimental +object KinesisUtils { + /** + * Create an InputDStream that pulls messages from a Kinesis stream. + * + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ + def createStream( + ssc: StreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, + checkpointInterval, initialPositionInStream, storageLevel)) + } + + /** + * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. 
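+   *
+   * A minimal sketch of calling this from Java (the stream name and endpoint below are
+   * placeholders):
+   * {{{
+   *   JavaReceiverInputDStream<byte[]> kinesisStream = KinesisUtils.createStream(jssc,
+   *     "mySparkStream", "https://kinesis.us-east-1.amazonaws.com", new Duration(2000),
+   *     InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2());
+   * }}}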
+ * + * @param jssc Java StreamingContext object + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return JavaReceiverInputDStream[Array[Byte]] + */ + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { + jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, + endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + } +} diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java new file mode 100644 index 0000000000000..87954a31f60ce --- /dev/null +++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +/** + * Demonstrate the use of the KinesisUtils Java API + */ +public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { + @Test + public void testKinesisStream() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()); + + ssc.stop(); + } +} diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e01e049595475 --- /dev/null +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala new file mode 100644 index 0000000000000..41dbd64c2b1fa --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.JavaConversions.seqAsJavaList + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.Seconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers +import org.scalatest.mock.EasyMockSugar + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + */ +class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter + with EasyMockSugar { + + val app = "TestKinesisReceiver" + val stream = "mySparkStream" + val endpoint = "endpoint-url" + val workerId = "dummyWorkerId" + val shardId = "dummyShardId" + + val record1 = new Record() + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + val record2 = new Record() + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) + val batch = List[Record](record1, record2) + + var receiverMock: KinesisReceiver = _ + var checkpointerMock: IRecordProcessorCheckpointer = _ + var checkpointClockMock: ManualClock = _ + var checkpointStateMock: KinesisCheckpointState = _ + var currentClockMock: Clock = _ + + override def beforeFunction() = { + receiverMock = mock[KinesisReceiver] + checkpointerMock = mock[IRecordProcessorCheckpointer] + checkpointClockMock = mock[ManualClock] + checkpointStateMock = mock[KinesisCheckpointState] + currentClockMock = mock[Clock] + } + + test("kinesis utils api") { + val ssc = new StreamingContext(master, framework, batchDuration) + // Tests the API, does not actually test data receiving + val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + ssc.stop() + } + + test("process records including store and checkpoint") { + val expectedCheckpointIntervalMillis = 10 + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).once() + receiverMock.store(record2.getData().array()).once() + checkpointStateMock.shouldCheckpoint().andReturn(true).once() + checkpointerMock.checkpoint().once() + checkpointStateMock.advanceCheckpoint().once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + 
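The test above exercises the Scala KinesisUtils.createStream API with the same arguments the Java suite uses. A minimal usage sketch follows; the application name, stream name, endpoint URL, and intervals are placeholders, and it assumes AWS credentials are available through the default provider chain.

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream

// Placeholder master/app name; batch and checkpoint intervals are illustrative.
val conf = new SparkConf().setMaster("local[2]").setAppName("KinesisSketch")
val ssc = new StreamingContext(conf, Seconds(2))

val byteStream = KinesisUtils.createStream(
  ssc, "mySparkStream", "https://kinesis.us-west-2.amazonaws.com",
  Seconds(2), InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)

// Each record arrives as Array[Byte]; decode before use.
byteStream.map(bytes => new String(bytes)).print()

ssc.start()
ssc.awaitTermination()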
test("shouldn't store and checkpoint when receiver is stopped") { + expecting { + receiverMock.isStopped().andReturn(true).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + test("shouldn't checkpoint when exception occurs during store") { + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + } + + test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + } + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + } + } + + test("should add to time when advancing checkpoint") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) + } + } + + test("shutdown should checkpoint if the reason is TERMINATE") { + expecting { + checkpointerMock.checkpoint().once() + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, reason) + } + } + + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { + expecting { + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + } + } + + test("retry success on first attempt") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = 
KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis throttling exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new ThrottlingException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis dependency exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry failed after a shutdown exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after an invalid state exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after unexpected exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after exhausing all retries") { + val expectedErrorMessage = "final try error message" + expecting { + checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) + .andThrow(new ThrottlingException(expectedErrorMessage)).once() + } + whenExecuting(checkpointerMock) { + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + exception.getMessage().shouldBe(expectedErrorMessage) + } + } +} diff --git a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index d03d7774e8c80..3b1880e143513 100644 --- a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -82,5 +82,9 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry, override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/graphx/pom.xml b/graphx/pom.xml index 7e3bcf29dcfbc..6dd52fc618b1e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-graphx_2.10 - graphx + graphx jar Spark Project GraphX diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index 5318b8da6412a..714f3b81c9dad 100644 --- 
a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -28,7 +28,7 @@ import org.apache.spark.rdd.{ShuffledRDD, RDD} private[graphx] class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner) + val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner) // Set a custom serializer if the data is of int or double type. if (classTag[VD] == ClassTag.Int) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index a565d3b28bf52..b27485953f719 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -33,7 +33,7 @@ private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage]( + new ShuffledRDD[VertexId, Int, Int]( self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 635514f09ece0..60149548ab852 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -100,8 +100,10 @@ object GraphGenerators { */ private def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int): Int = { val rand = new Random() - val m = math.exp(mu + (sigma * sigma) / 2.0) - val s = math.sqrt((math.exp(sigma*sigma) - 1) * math.exp(2*mu + sigma*sigma)) + val sigmaSq = sigma * sigma + val m = math.exp(mu + sigmaSq / 2.0) + // expm1 is exp(m)-1 with better accuracy for tiny m + val s = math.sqrt(math.expm1(sigmaSq) * math.exp(2*mu + sigmaSq)) // Z ~ N(0, 1) var X: Double = maxVal diff --git a/make-distribution.sh b/make-distribution.sh index c08093f46b61f..f7a6a9d838bb6 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -128,7 +128,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then if [[ ! $REPLY =~ ^[Yy]$ ]]; then echo "Okay, exiting." 
exit 1 - fi + fi fi if [ "$NAME" == "none" ]; then @@ -150,7 +150,7 @@ else fi # Build uber fat JAR -cd $FWDIR +cd "$FWDIR" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" @@ -168,22 +168,22 @@ mkdir -p "$DISTDIR/lib" echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE" # Copy jars -cp $FWDIR/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" -cp $FWDIR/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" +cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" +cp -r "$FWDIR"/examples/src/main "$DISTDIR/examples/src/" if [ "$SPARK_HIVE" == "true" ]; then - cp $FWDIR/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" + cp "$FWDIR"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" fi # Copy license and ASF files cp "$FWDIR/LICENSE" "$DISTDIR" cp "$FWDIR/NOTICE" "$DISTDIR" -if [ -e $FWDIR/CHANGES.txt ]; then +if [ -e "$FWDIR"/CHANGES.txt ]; then cp "$FWDIR/CHANGES.txt" "$DISTDIR" fi @@ -199,7 +199,7 @@ cp -r "$FWDIR/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.4.1" + TACHYON_VERSION="0.5.0" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` diff --git a/mllib/pom.xml b/mllib/pom.xml index 92b07e2357db1..9a33bd1cf6ad1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-mllib_2.10 - mllib + mllib jar Spark Project ML Library @@ -40,6 +40,11 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server @@ -60,6 +65,10 @@ junit junit + + org.apache.commons + commons-math3 + @@ -72,6 +81,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 954621ee8b933..fd0b9556c7d54 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,15 +19,28 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, 
Variance} +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -225,7 +238,7 @@ class PythonMLLibAPI extends Serializable { jsc: JavaSparkContext, path: String, minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint).toJavaRDD() + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, @@ -248,15 +261,28 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regParam: Double, + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val lrAlg = new LinearRegressionWithSGD() + lrAlg.setIntercept(intercept) + lrAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + lrAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + lrAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LinearRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + lrAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -316,16 +342,27 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val SVMAlg = new SVMWithSGD() + SVMAlg.setIntercept(intercept) + SVMAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + SVMAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + SVMAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." 
+ + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - SVMWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + SVMAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -338,15 +375,28 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regParam: Double, + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val LogRegAlg = new LogisticRegressionWithSGD() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + LogRegAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LogisticRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + LogRegAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -453,4 +503,201 @@ class PythonMLLibAPI extends Serializable { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + + /** + * Java stub for Python mllib DecisionTree.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + * @param dataBytesJRDD Training data + * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + */ + def trainDecisionTreeModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + impurityStr: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + + val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + + val algo: Algo = algoStr match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") + } + val impurity: Impurity = impurityStr match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") + } + + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) + + DecisionTree.train(data, strategy) + } + + /** + * Predict the label of the given data point. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param featuresBytes Serialized feature vector for data point + * @return predicted label + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + featuresBytes: Array[Byte]): Double = { + val features: Vector = deserializeDoubleVector(featuresBytes) + model.predict(features) + } + + /** + * Predict the labels of the given data points. 
+ * This is a Java stub for python DecisionTreeModel.predict() + * + * @param dataJRDD A JavaRDD with serialized feature vectors + * @return JavaRDD of serialized predictions + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { + val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + model.predict(data).map(serializeDouble) + } + + /** + * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String). + * Returns the correlation matrix serialized into a byte array understood by deserializers in + * pyspark. + */ + def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { + val inputMatrix = X.rdd.map(deserializeDoubleVector(_)) + val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) + serializeDoubleMatrix(to2dArray(result)) + } + + /** + * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). + */ + def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { + val xDeser = x.rdd.map(deserializeDouble(_)) + val yDeser = y.rdd.map(deserializeDouble(_)) + Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) + } + + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark + private def getCorrNameOrDefault(method: String) = { + if (method == null) CorrelationNames.defaultCorrName else method + } + + // Reformat a Matrix into Array[Array[Double]] for serialization + private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { + val values = matrix.toArray + Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + } + + // Used by the *RDD methods to get default seed if not passed in from pyspark + private def getSeedOrDefault(seed: java.lang.Long): Long = { + if (seed == null) Utils.random.nextLong else seed + } + + // Used by *RDD methods to get default numPartitions if not passed in from pyspark + private def getNumPartitionsOrDefault(numPartitions: java.lang.Integer, + jsc: JavaSparkContext): Int = { + if (numPartitions == null) { + jsc.sc.defaultParallelism + } else { + numPartitions + } + } + + // Note: for the following methods, numPartitions and seed are boxed to allow nulls to be passed + // in for either argument from pyspark + + /** + * Java stub for Python mllib RandomRDDGenerators.uniformRDD() + */ + def uniformRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalRDD() + */ + def normalRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonRDD() + */ + def poissonRDD(jsc: JavaSparkContext, + mean: Double, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib 
RandomRDDGenerators.uniformVectorRDD() + */ + def uniformVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalVectorRDD() + */ + def normalVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonVectorRDD() + */ + def poissonVectorRDD(jsc: JavaSparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index db425d866bbad..fce8fe29f6e40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -52,13 +52,13 @@ class KMeans private ( def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) /** Set the number of clusters to create (k). Default: 2. */ - def setK(k: Int): KMeans = { + def setK(k: Int): this.type = { this.k = k this } /** Set maximum number of iterations to run. Default: 20. */ - def setMaxIterations(maxIterations: Int): KMeans = { + def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } @@ -68,7 +68,7 @@ class KMeans private ( * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. */ - def setInitializationMode(initializationMode: String): KMeans = { + def setInitializationMode(initializationMode: String): this.type = { if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) } @@ -83,7 +83,7 @@ class KMeans private ( * return the best clustering found over any run. Default: 1. */ @Experimental - def setRuns(runs: Int): KMeans = { + def setRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") } @@ -95,7 +95,7 @@ class KMeans private ( * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. */ - def setInitializationSteps(initializationSteps: Int): KMeans = { + def setInitializationSteps(initializationSteps: Int): this.type = { if (initializationSteps <= 0) { throw new IllegalArgumentException("Number of initialization steps must be positive") } @@ -107,7 +107,7 @@ class KMeans private ( * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. 
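The switch of the setter return types to this.type above is what allows the setters to be chained fluently. A minimal sketch, assuming an existing SparkContext named sc and made-up points:

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Illustrative 2-D points forming two obvious clusters.
val points = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))

val model = new KMeans()
  .setK(2)
  .setMaxIterations(10)
  .setInitializationMode(KMeans.K_MEANS_PARALLEL)
  .setRuns(1)
  .run(points)

println(model.clusterCenters.mkString(", "))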
*/ - def setEpsilon(epsilon: Double): KMeans = { + def setEpsilon(epsilon: Double): this.type = { this.epsilon = epsilon this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala new file mode 100644 index 0000000000000..0f6d5809e098f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + * + * @param numFeatures number of features (default: 1000000) + */ +@Experimental +class HashingTF(val numFeatures: Int) extends Serializable { + + def this() = this(1000000) + + /** + * Returns the index of the input term. + */ + def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) + + /** + * Transforms the input document into a sparse term frequency vector. + */ + def transform(document: Iterable[_]): Vector = { + val termFrequencies = mutable.HashMap.empty[Int, Double] + document.foreach { term => + val i = indexOf(term) + termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0) + } + Vectors.sparse(numFeatures, termFrequencies.toSeq) + } + + /** + * Transforms the input document into a sparse term frequency vector (Java version). + */ + def transform(document: JavaIterable[_]): Vector = { + transform(document.asScala) + } + + /** + * Transforms the input document to term frequency vectors. + */ + def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { + dataset.map(this.transform) + } + + /** + * Transforms the input document to term frequency vectors (Java version). + */ + def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { + dataset.rdd.map(this.transform).toJavaRDD() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala new file mode 100644 index 0000000000000..7ed611a857acc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
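A quick sketch of the hashing trick implemented by the HashingTF class above; the documents are illustrative and an existing SparkContext named sc is assumed.

import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

val hashingTF = new HashingTF()            // 1,000,000 feature buckets by default
val docs: RDD[Seq[String]] = sc.parallelize(Seq(
  "spark streaming with kinesis".split(" ").toSeq,
  "spark mllib feature hashing".split(" ").toSeq))

// Each document becomes a sparse vector of term frequencies.
val tf: RDD[Vector] = hashingTF.transform(docs)
println(hashingTF.indexOf("spark"))        // bucket index assigned to the term "spark"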
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Inverse document frequency (IDF). + * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total + * number of documents and `d(t)` is the number of documents that contain term `t`. + */ +@Experimental +class IDF { + + // TODO: Allow different IDF formulations. + + private var brzIdf: BDV[Double] = _ + + /** + * Computes the inverse document frequency. + * @param dataset an RDD of term frequency vectors + */ + def fit(dataset: RDD[Vector]): this.type = { + brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + seqOp = (df, v) => df.add(v), + combOp = (df1, df2) => df1.merge(df2) + ).idf() + this + } + + /** + * Computes the inverse document frequency. + * @param dataset a JavaRDD of term frequency vectors + */ + def fit(dataset: JavaRDD[Vector]): this.type = { + fit(dataset.rdd) + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors. + * @param dataset an RDD of term frequency vectors + * @return an RDD of TF-IDF vectors + */ + def transform(dataset: RDD[Vector]): RDD[Vector] = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") + } + val theIdf = brzIdf + val bcIdf = dataset.context.broadcast(theIdf) + dataset.mapPartitions { iter => + val thisIdf = bcIdf.value + iter.map { v => + val n = v.size + v match { + case sv: SparseVector => + val nnz = sv.indices.size + val newValues = new Array[Double](nnz) + var k = 0 + while (k < nnz) { + newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) + k += 1 + } + Vectors.sparse(n, sv.indices, newValues) + case dv: DenseVector => + val newValues = new Array[Double](n) + var j = 0 + while (j < n) { + newValues(j) = dv.values(j) * thisIdf(j) + j += 1 + } + Vectors.dense(newValues) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } + } + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). + * @param dataset a JavaRDD of term frequency vectors + * @return a JavaRDD of TF-IDF vectors + */ + def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(dataset.rdd).toJavaRDD() + } + + /** Returns the IDF vector. */ + def idf(): Vector = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") + } + Vectors.fromBreeze(brzIdf) + } + + private def initialized: Boolean = brzIdf != null +} + +private object IDF { + + /** Document frequency aggregator. 
*/ + class DocumentFrequencyAggregator extends Serializable { + + /** number of documents */ + private var m = 0L + /** document frequency vector */ + private var df: BDV[Long] = _ + + /** Adds a new document. */ + def add(doc: Vector): this.type = { + if (isEmpty) { + df = BDV.zeros(doc.size) + } + doc match { + case sv: SparseVector => + val nnz = sv.indices.size + var k = 0 + while (k < nnz) { + if (sv.values(k) > 0) { + df(sv.indices(k)) += 1L + } + k += 1 + } + case dv: DenseVector => + val n = dv.size + var j = 0 + while (j < n) { + if (dv.values(j) > 0.0) { + df(j) += 1L + } + j += 1 + } + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + m += 1L + this + } + + /** Merges another. */ + def merge(other: DocumentFrequencyAggregator): this.type = { + if (!other.isEmpty) { + m += other.m + if (df == null) { + df = other.df.copy + } else { + df += other.df + } + } + this + } + + private def isEmpty: Boolean = m == 0L + + /** Returns the current IDF vector. */ + def idf(): BDV[Double] = { + if (isEmpty) { + throw new IllegalStateException("Haven't seen any document yet.") + } + val n = df.length + val inv = BDV.zeros[Double](n) + var j = 0 + while (j < n) { + inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) + j += 1 + } + inv + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala new file mode 100644 index 0000000000000..ea9fd0a80d8e0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * :: DeveloperApi :: + * Normalizes samples individually to unit L^p^ norm + * + * For any 1 <= p < Double.PositiveInfinity, normalizes samples using + * sum(abs(vector).^p^)^(1/p)^ as norm. + * + * For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization. + * + * @param p Normalization in L^p^ space, p = 2 by default. + */ +@DeveloperApi +class Normalizer(p: Double) extends VectorTransformer { + + def this() = this(2) + + require(p >= 1.0) + + /** + * Applies unit length normalization on a vector. + * + * @param vector vector to be normalized. + * @return normalized vector. If the norm of the input is zero, it will return the input vector. 
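Continuing that sketch, the IDF transformer above is typically fitted on the hashed term-frequency vectors and then used to produce TF-IDF vectors; the tf value carries over from the HashingTF sketch.

import org.apache.spark.mllib.feature.IDF

// `tf` is the RDD[Vector] of term frequencies from the HashingTF sketch above.
tf.cache()                        // fit() and transform() each traverse the data
val idf = new IDF().fit(tf)       // computes log((m + 1) / (d(t) + 1)) per term bucket
val tfidf = idf.transform(tf)     // TF-IDF vectors with the same sparsity as the input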
+ */ + override def transform(vector: Vector): Vector = { + var norm = vector.toBreeze.norm(p) + + if (norm != 0.0) { + // For dense vector, we've to allocate new memory for new output vector. + // However, for sparse vector, the `index` array will not be changed, + // so we can re-use it to save memory. + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) + case sv: BSV[Double] => + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) /= norm + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Since the norm is zero, return the input vector object itself. + // Note that it's safe since we always assume that the data in RDD + // should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala new file mode 100644 index 0000000000000..cc2d7579c2901 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. + * + * @param withMean False by default. Centers the data with mean before scaling. It will build a + * dense output, so this does not work on sparse input and will raise an exception. + * @param withStd True by default. Scales the data to unit standard deviation. + */ +@DeveloperApi +class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { + + def this() = this(false, true) + + require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") + + private var mean: BV[Double] = _ + private var factor: BV[Double] = _ + + /** + * Computes the mean and variance and stores as a model to be used for later scaling. + * + * @param data The data used to compute the mean and variance to build the transformation model. + * @return This StandardScalar object. 
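A short usage sketch for the Normalizer above; the vectors and printed values are illustrative.

import org.apache.spark.mllib.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

val l2 = new Normalizer()                      // p = 2 by default
val lInf = new Normalizer(Double.PositiveInfinity)

val v = Vectors.dense(3.0, -4.0)
println(l2.transform(v))                       // [0.6, -0.8]  (unit L2 norm, since ||v||_2 = 5)
println(lInf.transform(v))                     // [0.75, -1.0] (scaled so max |component| = 1)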
+ */ + def fit(data: RDD[Vector]): this.type = { + val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + + mean = summary.mean.toBreeze + factor = summary.variance.toBreeze + require(mean.length == factor.length) + + var i = 0 + while (i < factor.length) { + factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 + i += 1 + } + + this + } + + /** + * Applies standardization transformation on a vector. + * + * @param vector Vector to be standardized. + * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` + * for the column with zero variance. + */ + override def transform(vector: Vector): Vector = { + if (mean == null || factor == null) { + throw new IllegalStateException( + "Haven't learned column summary statistics yet. Call fit first.") + } + + require(vector.size == mean.length) + + if (withMean) { + vector.toBreeze match { + case dv: BDV[Double] => + val output = vector.toBreeze.copy + var i = 0 + while (i < output.length) { + output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else if (withStd) { + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) + case sv: BSV[Double] => + // For sparse vector, the `index` array inside sparse vector object will not be changed, + // so we can re-use it to save memory. + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) *= factor(output.index(i)) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Note that it's safe since we always assume that the data in RDD should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala new file mode 100644 index 0000000000000..415a845332d45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Trait for transformation of a vector + */ +@DeveloperApi +trait VectorTransformer extends Serializable { + + /** + * Applies transformation on a vector. 
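A corresponding sketch for StandardScaler; it assumes an existing SparkContext named sc and uses made-up dense points.

import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0),
  Vectors.dense(2.0, 20.0),
  Vectors.dense(3.0, 30.0)))

// withMean requires dense input; withStd alone also works on sparse vectors.
val scaler = new StandardScaler(withMean = true, withStd = true).fit(data)
val scaled = scaler.transform(data)
scaled.collect().foreach(println)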
+ * + * @param vector vector to be transformed. + * @return transformed vector. + */ + def transform(vector: Vector): Vector + + /** + * Applies transformation on an RDD[Vector]. + * + * @param data RDD[Vector] to be transformed. + * @return transformed RDD[Vector]. + */ + def transform(data: RDD[Vector]): RDD[Vector] = { + // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. + // So it should be no longer necessary to explicitly broadcast `this` object. + data.map(x => this.transform(x)) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala new file mode 100644 index 0000000000000..3bf44ad7c44e3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd._ +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom + +/** + * Entry in vocabulary + */ +private case class VocabWord( + var word: String, + var cn: Int, + var point: Array[Int], + var code: Array[Int], + var codeLen:Int +) + +/** + * :: Experimental :: + * Word2Vec creates vector representation of words in a text corpus. + * The algorithm first constructs a vocabulary from the corpus + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in + * natural language processing and machine learning algorithms. + * + * We used skip-gram model in our implementation and hierarchical softmax + * method to train the model. The variable names in the implementation + * matches the original C implementation. + * + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see + * Efficient Estimation of Word Representations in Vector Space + * and + * Distributed Representations of Words and Phrases and their Compositionality. + */ +@Experimental +class Word2Vec extends Serializable with Logging { + + private var vectorSize = 100 + private var startingAlpha = 0.025 + private var numPartitions = 1 + private var numIterations = 1 + private var seed = Utils.random.nextLong() + + /** + * Sets vector size (default: 100). 
+ */ + def setVectorSize(vectorSize: Int): this.type = { + this.vectorSize = vectorSize + this + } + + /** + * Sets initial learning rate (default: 0.025). + */ + def setLearningRate(learningRate: Double): this.type = { + this.startingAlpha = learningRate + this + } + + /** + * Sets number of partitions (default: 1). Use a small number for accuracy. + */ + def setNumPartitions(numPartitions: Int): this.type = { + require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") + this.numPartitions = numPartitions + this + } + + /** + * Sets number of iterations (default: 1), which should be smaller than or equal to number of + * partitions. + */ + def setNumIterations(numIterations: Int): this.type = { + this.numIterations = numIterations + this + } + + /** + * Sets random seed (default: a random long integer). + */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + private val EXP_TABLE_SIZE = 1000 + private val MAX_EXP = 6 + private val MAX_CODE_LENGTH = 40 + private val MAX_SENTENCE_LENGTH = 1000 + private val layer1Size = vectorSize + + /** context words from [-window, window] */ + private val window = 5 + + /** minimum frequency to consider a vocabulary word */ + private val minCount = 5 + + private var trainWordsCount = 0 + private var vocabSize = 0 + private var vocab: Array[VocabWord] = null + private var vocabHash = mutable.HashMap.empty[String, Int] + private var alpha = startingAlpha + + private def learnVocab(words: RDD[String]): Unit = { + vocab = words.map(w => (w, 1)) + .reduceByKey(_ + _) + .map(x => VocabWord( + x._1, + x._2, + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + 0)) + .filter(_.cn >= minCount) + .collect() + .sortWith((a, b) => a.cn > b.cn) + + vocabSize = vocab.length + var a = 0 + while (a < vocabSize) { + vocabHash += vocab(a).word -> a + trainWordsCount += vocab(a).cn + a += 1 + } + logInfo("trainWordsCount = " + trainWordsCount) + } + + private def createExpTable(): Array[Float] = { + val expTable = new Array[Float](EXP_TABLE_SIZE) + var i = 0 + while (i < EXP_TABLE_SIZE) { + val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP) + expTable(i) = (tmp / (tmp + 1.0)).toFloat + i += 1 + } + expTable + } + + private def createBinaryTree(): Unit = { + val count = new Array[Long](vocabSize * 2 + 1) + val binary = new Array[Int](vocabSize * 2 + 1) + val parentNode = new Array[Int](vocabSize * 2 + 1) + val code = new Array[Int](MAX_CODE_LENGTH) + val point = new Array[Int](MAX_CODE_LENGTH) + var a = 0 + while (a < vocabSize) { + count(a) = vocab(a).cn + a += 1 + } + while (a < 2 * vocabSize) { + count(a) = 1e9.toInt + a += 1 + } + var pos1 = vocabSize - 1 + var pos2 = vocabSize + + var min1i = 0 + var min2i = 0 + + a = 0 + while (a < vocabSize - 1) { + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min1i = pos1 + pos1 -= 1 + } else { + min1i = pos2 + pos2 += 1 + } + } else { + min1i = pos2 + pos2 += 1 + } + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min2i = pos1 + pos1 -= 1 + } else { + min2i = pos2 + pos2 += 1 + } + } else { + min2i = pos2 + pos2 += 1 + } + count(vocabSize + a) = count(min1i) + count(min2i) + parentNode(min1i) = vocabSize + a + parentNode(min2i) = vocabSize + a + binary(min2i) = 1 + a += 1 + } + // Now assign binary code to each vocabulary word + var i = 0 + a = 0 + while (a < vocabSize) { + var b = a + i = 0 + while (b != vocabSize * 2 - 2) { + code(i) = binary(b) + point(i) = b + i += 1 + b = parentNode(b) + } + vocab(a).codeLen = i + 
vocab(a).point(0) = vocabSize - 2 + b = 0 + while (b < i) { + vocab(a).code(i - b - 1) = code(b) + vocab(a).point(i - b) = point(b) - vocabSize + b += 1 + } + a += 1 + } + } + + /** + * Computes the vector representation of each word in vocabulary. + * @param dataset an RDD of words + * @return a Word2VecModel + */ + def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { + + val words = dataset.flatMap(x => x) + + learnVocab(words) + + createBinaryTree() + + val sc = dataset.context + + val expTable = sc.broadcast(createExpTable()) + val bcVocab = sc.broadcast(vocab) + val bcVocabHash = sc.broadcast(vocabHash) + + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => + new Iterator[Array[Int]] { + def hasNext: Boolean = iter.hasNext + + def next(): Array[Int] = { + var sentence = new ArrayBuffer[Int] + var sentenceLength = 0 + while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { + val word = bcVocabHash.value.get(iter.next()) + word match { + case Some(w) => + sentence += w + sentenceLength += 1 + case None => + } + } + sentence.toArray + } + } + } + + val newSentences = sentences.repartition(numPartitions).cache() + val initRandom = new XORShiftRandom(seed) + var syn0Global = + Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) + var syn1Global = new Array[Float](vocabSize * layer1Size) + + for (k <- 1 to numIterations) { + val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => + val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + case ((syn0, syn1, lastWordCount, wordCount), sentence) => + var lwc = lastWordCount + var wc = wordCount + if (wordCount - lastWordCount > 10000) { + lwc = wordCount + // TODO: discount by iteration? 
+ alpha = + startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + } + wc += sentence.size + var pos = 0 + while (pos < sentence.size) { + val word = sentence(pos) + val b = random.nextInt(window) + // Train Skip-gram + var a = b + while (a < window * 2 + 1 - b) { + if (a != window) { + val c = pos - window + a + if (c >= 0 && c < sentence.size) { + val lastWord = sentence(c) + val l1 = lastWord * layer1Size + val neu1e = new Array[Float](layer1Size) + // Hierarchical softmax + var d = 0 + while (d < bcVocab.value(word).codeLen) { + val l2 = bcVocab.value(word).point(d) * layer1Size + // Propagate hidden -> output + var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) + if (f > -MAX_EXP && f < MAX_EXP) { + val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt + f = expTable.value(ind) + val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat + blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + } + d += 1 + } + blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) + } + } + a += 1 + } + pos += 1 + } + (syn0, syn1, lwc, wc) + } + Iterator(model) + } + val (aggSyn0, aggSyn1, _, _) = + partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) + val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) + blas.sscal(n, weight1, syn0_1, 1) + blas.sscal(n, weight1, syn1_1, 1) + blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) + blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + } + syn0Global = aggSyn0 + syn1Global = aggSyn1 + } + newSentences.unpersist() + + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] + var i = 0 + while (i < vocabSize) { + val word = bcVocab.value(i).word + val vector = new Array[Float](layer1Size) + Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) + word2VecMap += word -> vector + i += 1 + } + + new Word2VecModel(word2VecMap.toMap) + } +} + +/** +* Word2Vec model + */ +class Word2VecModel private[mllib] ( + private val model: Map[String, Array[Float]]) extends Serializable { + + private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { + require(v1.length == v2.length, "Vectors should have the same length") + val n = v1.length + val norm1 = blas.snrm2(n, v1, 1) + val norm2 = blas.snrm2(n, v2, 1) + if (norm1 == 0 || norm2 == 0) return 0.0 + blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + } + + /** + * Transforms a word to its vector representation + * @param word a word + * @return vector representation of word + */ + def transform(word: String): Vector = { + model.get(word) match { + case Some(vec) => + Vectors.dense(vec.map(_.toDouble)) + case None => + throw new IllegalStateException(s"$word not in vocabulary") + } + } + + /** + * Transforms an RDD to its vector representation + * @param dataset a an RDD of words + * @return RDD of vector representation + */ + def transform(dataset: RDD[String]): RDD[Vector] = { + dataset.map(word => transform(word)) + } + + /** + * Find synonyms of a word + * @param word a word + * @param num number of synonyms to find + * @return array of (word, similarity) + */ + def findSynonyms(word: String, num: Int): Array[(String, Double)] = { + val vector = transform(word) + findSynonyms(vector,num) + } + + /** + * Find synonyms of the vector 
representation of a word + * @param vector vector representation of a word + * @param num number of synonyms to find + * @return array of (word, cosineSimilarity) + */ + def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { + require(num > 0, "Number of similar words should > 0") + // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) + model.mapValues(vec => cosineSimilarity(fVector, vec)) + .toSeq + .sortBy(- _._2) + .take(num + 1) + .tail + .toArray + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8c2b044ea73f2..45486b2c7d82d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import java.util.Arrays -import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} /** @@ -79,7 +80,7 @@ class RowMatrix( private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = { val n = numCols().toInt val vbr = rows.context.broadcast(v) - rows.aggregate(BDV.zeros[Double](n))( + rows.treeAggregate(BDV.zeros[Double](n))( seqOp = (U, r) => { val rBrz = r.toBreeze val a = rBrz.dot(vbr.value) @@ -91,9 +92,7 @@ class RowMatrix( s"Do not support vector operation from type ${rBrz.getClass.getName}.") } U - }, - combOp = (U1, U2) => U1 += U2 - ) + }, combOp = (U1, U2) => U1 += U2) } /** @@ -104,13 +103,11 @@ class RowMatrix( val nt: Int = n * (n + 1) / 2 // Compute the upper triangular part of the gram matrix. - val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))( + val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { RowMatrix.dspr(1.0, v, U.data) U - }, - combOp = (U1, U2) => U1 += U2 - ) + }, combOp = (U1, U2) => U1 += U2) RowMatrix.triuToFull(n, GU.data) } @@ -290,9 +287,10 @@ class RowMatrix( s"We need at least $mem bytes of memory.") } - val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( + val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), - combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2) + combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => + (s1._1 + s2._1, s1._2 += s2._2) ) updateNumRows(m) @@ -353,10 +351,9 @@ class RowMatrix( * Computes column-wise summary statistics. 
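A minimal usage sketch of the Word2Vec API added above, assuming an existing SparkContext `sc` and a whitespace-tokenized corpus; the file path and the `org.apache.spark.mllib.feature` package name are assumptions, while `fit`, `transform`, and `findSynonyms` come from the code above:

    import org.apache.spark.mllib.feature.Word2Vec

    // One tokenized sentence per input line (illustrative path).
    val corpus = sc.textFile("data/text8.txt").map(_.split(" ").toSeq)

    val model = new Word2Vec().fit(corpus)          // trains and returns a Word2VecModel

    val vector = model.transform("spark")           // vector for a word in the vocabulary
    val synonyms = model.findSynonyms("spark", 5)   // (word, cosine similarity) pairs
    synonyms.foreach { case (word, sim) => println(s"$word\t$sim") }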
*/ def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { - val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)( + val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), - (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - ) + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) updateNumRows(summary.count) summary } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 679842f831c2a..9d82f011e674a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -68,9 +68,9 @@ class LogisticGradient extends Gradient { val gradient = brzData * gradientMultiplier val loss = if (label > 0) { - math.log(1 + math.exp(margin)) + math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p } else { - math.log(1 + math.exp(margin)) - margin + math.log1p(math.exp(margin)) - margin } (Vectors.fromBreeze(gradient), loss) @@ -89,9 +89,9 @@ class LogisticGradient extends Gradient { brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) if (label > 0) { - math.log(1 + math.exp(margin)) + math.log1p(math.exp(margin)) } else { - math.log(1 + math.exp(margin)) - margin + math.log1p(math.exp(margin)) - margin } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 9fd760bf78083..a6912056395d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -25,6 +25,7 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Class used to solve an optimization problem using Gradient Descent. 
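The `math.log1p` change above is a numerical-stability fix: when `exp(margin)` is tiny, `1 + p` rounds to `1.0` in double precision, so the naive `log(1 + p)` returns exactly 0 and the loss contribution is lost. A quick check in plain Scala:

    val p = 1e-17          // roughly exp(margin) for a strongly negative margin
    math.log(1 + p)        // 0.0 -- the small term is rounded away
    math.log1p(p)          // ~1.0e-17 -- log(1 + p) computed accurately for small p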
@@ -161,6 +162,14 @@ object GradientDescent extends Logging { val numExamples = data.count() val miniBatchSize = numExamples * miniBatchFraction + // if no data, return initial weights to avoid NaNs + if (numExamples == 0) { + + logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found") + return (initialWeights, stochasticLossHistory.toArray) + + } + // Initialize weights as a column vector var weights = Vectors.dense(initialWeights.toArray) val n = weights.size @@ -177,7 +186,7 @@ object GradientDescent extends Logging { // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .aggregate((BDV.zeros[Double](n), 0.0))( + .treeAggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) (grad, loss + l) @@ -201,5 +210,6 @@ object GradientDescent extends Logging { stochasticLossHistory.takeRight(10).mkString(", "))) (weights, stochasticLossHistory.toArray) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 179cd4a3f1625..26a2b62e76ed0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -26,6 +26,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.rdd.RDDFunctions._ /** * :: DeveloperApi :: @@ -199,7 +200,7 @@ object LBFGS extends Logging { val n = weights.length val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))( + val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala similarity index 80% rename from mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala rename to mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 7ecb409c4a91a..9cab49f6ed1f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -25,21 +25,21 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** * :: Experimental :: - * Trait for random number generators that generate i.i.d. values from a distribution. + * Trait for random data generators that generate i.i.d. data. */ @Experimental -trait DistributionGenerator extends Pseudorandom with Serializable { +trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** - * Returns an i.i.d. sample as a Double from an underlying distribution. + * Returns an i.i.d. sample as a generic type from an underlying distribution. 
*/ - def nextValue(): Double + def nextValue(): T /** - * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the + * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the * class when applicable for non-locking concurrent usage. */ - def copy(): DistributionGenerator + def copy(): RandomDataGenerator[T] } /** @@ -47,7 +47,7 @@ trait DistributionGenerator extends Pseudorandom with Serializable { * Generates i.i.d. samples from U[0.0, 1.0] */ @Experimental -class UniformGenerator extends DistributionGenerator { +class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() @@ -66,7 +66,7 @@ class UniformGenerator extends DistributionGenerator { * Generates i.i.d. samples from the standard normal distribution. */ @Experimental -class StandardNormalGenerator extends DistributionGenerator { +class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() @@ -87,7 +87,7 @@ class StandardNormalGenerator extends DistributionGenerator { * @param mean mean for the Poisson distribution. */ @Experimental -class PoissonGenerator(val mean: Double) extends DistributionGenerator { +class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { private var rng = new Poisson(mean, new DRand) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala index d7ee2d3f46846..b0a0593223910 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -24,16 +24,21 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag + /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. */ @Experimental object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -49,7 +54,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -63,9 +71,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. 
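Because the trait is now generic in its value type, generators are no longer limited to `Double`. A hypothetical `RandomDataGenerator[String]`, written against the three members shown above (`nextValue`, `setSeed` inherited from `Pseudorandom`, and `copy`):

    import org.apache.spark.mllib.random.RandomDataGenerator

    class RandomStringGenerator extends RandomDataGenerator[String] {
      private val rng = new java.util.Random()

      // Eight random base-36 characters per sample.
      override def nextValue(): String =
        Seq.fill(8)(rng.nextInt(36)).map(java.lang.Integer.toString(_, 36)).mkString

      override def setSeed(seed: Long): Unit = rng.setSeed(seed)

      override def copy(): RandomStringGenerator = new RandomStringGenerator()
    }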
samples from the uniform distribution on [0.0, 1.0]. * sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n).map(v => a + (b - a) * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. @@ -77,7 +88,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -93,7 +107,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -107,9 +124,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n).map(v => mean + sigma * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). @@ -121,7 +141,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -142,7 +162,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -157,7 +177,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -172,7 +192,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. 
samples produced by the input DistributionGenerator. * * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -182,17 +202,17 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, numPartitions: Int, - seed: Long): RDD[Double] = { - new RandomRDD(sc, size, numPartitions, generator, seed) + seed: Long): RDD[T] = { + new RandomRDD[T](sc, size, numPartitions, generator, seed) } /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -201,16 +221,16 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, - numPartitions: Int): RDD[Double] = { - randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong) + numPartitions: Int): RDD[T] = { + randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong) } /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -219,17 +239,17 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, - size: Long): RDD[Double] = { - randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], + size: Long): RDD[T] = { + randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) } // TODO Generate RDD[Vector] from multivariate distributions. /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. @@ -251,14 +271,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. 
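A sketch of the generic `randomRDD` entry points above, together with the map-based transforms the new doc comments recommend; it assumes a SparkContext `sc`, and `RandomStringGenerator` is the hypothetical generator sketched earlier:

    import org.apache.spark.mllib.random.PoissonGenerator
    import org.apache.spark.mllib.random.RandomRDDGenerators._

    // i.i.d. Poisson(3.0) samples in 10 partitions.
    val pois = randomRDD(sc, new PoissonGenerator(3.0), 1000L, 10)

    // Non-Double element types now work as well.
    val strings = randomRDD(sc, new RandomStringGenerator, 1000L, 10)

    // U[0, 1] -> U[2, 5] and N(0, 1) -> N(1, 2), as suggested in the docs above.
    val u = uniformRDD(sc, 1000L, 10).map(v => 2.0 + (5.0 - 2.0) * v)
    val n = normalRDD(sc, 1000L, 10).map(v => 1.0 + 2.0 * v)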
*/ @Experimental def uniformVectorRDD(sc: SparkContext, @@ -270,14 +290,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. */ @Experimental def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -286,7 +306,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. @@ -294,7 +314,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, @@ -308,14 +328,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, @@ -327,14 +347,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -343,7 +363,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -352,7 +372,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. 
* @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -367,7 +387,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -375,7 +395,7 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -388,7 +408,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -396,7 +416,7 @@ object RandomRDDGenerators { * @param mean Mean, or lambda, for the Poisson distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -408,7 +428,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -417,11 +437,11 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int, @@ -431,7 +451,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -439,11 +459,11 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. 
*/ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int): RDD[Vector] = { @@ -452,7 +472,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -460,11 +480,11 @@ object RandomRDDGenerators { * @param generator DistributionGenerator used to populate the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int): RDD[Vector] = { randomVectorRDD(sc, generator, numRows, numCols, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 365b5e75d7f75..b5e403bc8c14d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.HashPartitioner +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. @@ -44,6 +47,69 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { new SlidingRDD[T](self, windowSize) } } + + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val cleanF = self.context.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. 
+ * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#aggregate]] + */ + def treeAggregate[U: ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + if (self.partitions.size == 0) { + return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance()) + } + val cleanSeqOp = self.context.clean(seqOp) + val cleanCombOp = self.context.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.size + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. + while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => + iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) + } } private[mllib] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f13282d07ff92..c8db3910c6eab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -19,35 +19,36 @@ package org.apache.spark.mllib.rdd import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} -import org.apache.spark.mllib.random.DistributionGenerator +import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag import scala.util.Random -private[mllib] class RandomRDDPartition(override val index: Int, +private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, - val generator: DistributionGenerator, + val generator: RandomDataGenerator[T], val seed: Long) extends Partition { require(size >= 0, "Non-negative partition size required.") } // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue -private[mllib] class RandomRDD(@transient sc: SparkContext, +private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, size: Long, numPartitions: Int, - @transient rng: DistributionGenerator, - @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) { + @transient rng: RandomDataGenerator[T], + @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue, "Partition size cannot exceed Int.MaxValue") - override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] - RandomRDD.getPointIterator(split) + override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { + val split = splitIn.asInstanceOf[RandomRDDPartition[T]] + RandomRDD.getPointIterator[T](split) } override def getPartitions: Array[Partition] = { @@ -59,7 
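The tree operations above are drop-in replacements for `reduce` and `aggregate` that merge per-partition results in `depth` levels (default 2) instead of sending every partial result straight to the driver. A sketch of how they are called, mirroring the MLlib changes in this patch; note the implicit conversion lives in a `private[mllib]` object, so this assumes code inside `org.apache.spark.mllib`:

    import org.apache.spark.mllib.rdd.RDDFunctions._

    val data = sc.parallelize(1L to 1000000L, 500)

    // Same results as data.reduce(_ + _) / data.aggregate(...), but partial results
    // are combined through intermediate reduceByKey stages before the final reduce.
    val sum = data.treeReduce(_ + _)
    val (total, count) = data.treeAggregate((0L, 0L))(
      seqOp = (acc, x) => (acc._1 + x, acc._2 + 1L),
      combOp = (a, b) => (a._1 + b._1, a._2 + b._2))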
+60,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, size: Long, vectorSize: Int, numPartitions: Int, - @transient rng: DistributionGenerator, + @transient rng: RandomDataGenerator[Double], @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { require(size > 0, "Positive RDD size required.") @@ -69,7 +70,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, "Partition size cannot exceed Int.MaxValue") override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] + val split = splitIn.asInstanceOf[RandomRDDPartition[Double]] RandomRDD.getVectorIterator(split, vectorSize) } @@ -80,12 +81,12 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, private[mllib] object RandomRDD { - def getPartitions(size: Long, + def getPartitions[T](size: Long, numPartitions: Int, - rng: DistributionGenerator, + rng: RandomDataGenerator[T], seed: Long): Array[Partition] = { - val partitions = new Array[RandomRDDPartition](numPartitions) + val partitions = new Array[RandomRDDPartition[T]](numPartitions) var i = 0 var start: Long = 0 var end: Long = 0 @@ -101,7 +102,7 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = { + def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(generator.nextValue()).toIterator @@ -109,7 +110,8 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = { + def getVectorIterator(partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(new DenseVector( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 5356790cb5339..8ebc7e27ed4dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.recommendation -import scala.collection.mutable.{ArrayBuffer, BitSet} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.math.{abs, sqrt} import scala.util.Random import scala.util.Sorting @@ -25,7 +26,7 @@ import scala.util.hashing.byteswap32 import org.jblas.{DoubleMatrix, SimpleBlas, Solve} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.{Logging, HashPartitioner, Partitioner} import org.apache.spark.storage.StorageLevel @@ -39,7 +40,8 @@ import org.apache.spark.mllib.optimization.NNLS * of the elements within this block, and the list of destination blocks that each user or * product will need to send its feature vector to. 
*/ -private[recommendation] case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[BitSet]) +private[recommendation] +case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[mutable.BitSet]) /** @@ -255,6 +257,9 @@ class ALS private ( rank, lambda, alpha, YtY) previousProducts.unpersist() logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + products.checkpoint() + } products.setName(s"products-$iter").persist() val XtX = Some(sc.broadcast(computeYtY(products))) val previousUsers = users @@ -268,6 +273,9 @@ class ALS private ( logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, rank, lambda, alpha, YtY = None) + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + products.checkpoint() + } products.setName(s"products-$iter") logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, @@ -284,8 +292,8 @@ class ALS private ( val usersOut = unblockFactors(users, userOutLinks) val productsOut = unblockFactors(products, productOutLinks) - usersOut.setName("usersOut").persist() - productsOut.setName("productsOut").persist() + usersOut.setName("usersOut").persist(StorageLevel.MEMORY_AND_DISK) + productsOut.setName("productsOut").persist(StorageLevel.MEMORY_AND_DISK) // Materialize usersOut and productsOut. usersOut.count() @@ -376,7 +384,7 @@ class ALS private ( val userIds = ratings.map(_.user).distinct.sorted val numUsers = userIds.length val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new BitSet(numProductBlocks)) + val shouldSend = Array.fill(numUsers)(new mutable.BitSet(numProductBlocks)) for (r <- ratings) { shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true } @@ -791,4 +799,120 @@ object ALS { : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) } + + /** + * :: DeveloperApi :: + * Statistics of a block in ALS computation. + * + * @param category type of this block, "user" or "product" + * @param index index of this block + * @param count number of users or products inside this block, the same as the number of + * least-squares problems to solve on this block in each iteration + * @param numRatings total number of ratings inside this block, the same as the number of outer + * products we need to make on this block in each iteration + * @param numInLinks total number of incoming links, the same as the number of vectors to retrieve + * before each iteration + * @param numOutLinks total number of outgoing links, the same as the number of vectors to send + * for the next iteration + */ + @DeveloperApi + case class BlockStats( + category: String, + index: Int, + count: Long, + numRatings: Long, + numInLinks: Long, + numOutLinks: Long) + + /** + * :: DeveloperApi :: + * Given an RDD of ratings, number of user blocks, and number of product blocks, computes the + * statistics of each block in ALS computation. This is useful for estimating cost and diagnosing + * load balance. 
+ * + * @param ratings an RDD of ratings + * @param numUserBlocks number of user blocks + * @param numProductBlocks number of product blocks + * @return statistics of user blocks and product blocks + */ + @DeveloperApi + def analyzeBlocks( + ratings: RDD[Rating], + numUserBlocks: Int, + numProductBlocks: Int): Array[BlockStats] = { + + val userPartitioner = new ALSPartitioner(numUserBlocks) + val productPartitioner = new ALSPartitioner(numProductBlocks) + + val ratingsByUserBlock = ratings.map { rating => + (userPartitioner.getPartition(rating.user), rating) + } + val ratingsByProductBlock = ratings.map { rating => + (productPartitioner.getPartition(rating.product), + Rating(rating.product, rating.user, rating.rating)) + } + + val als = new ALS() + val (userIn, userOut) = + als.makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, userPartitioner) + val (prodIn, prodOut) = + als.makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, productPartitioner) + + def sendGrid(outLinks: RDD[(Int, OutLinkBlock)]): Map[(Int, Int), Long] = { + outLinks.map { x => + val grid = new mutable.HashMap[(Int, Int), Long]() + val uPartition = x._1 + x._2.shouldSend.foreach { ss => + ss.foreach { pPartition => + val pair = (uPartition, pPartition) + grid.put(pair, grid.getOrElse(pair, 0L) + 1L) + } + } + grid + }.reduce { (grid1, grid2) => + grid2.foreach { x => + grid1.put(x._1, grid1.getOrElse(x._1, 0L) + x._2) + } + grid1 + }.toMap + } + + val userSendGrid = sendGrid(userOut) + val prodSendGrid = sendGrid(prodOut) + + val userInbound = new Array[Long](numUserBlocks) + val prodInbound = new Array[Long](numProductBlocks) + val userOutbound = new Array[Long](numUserBlocks) + val prodOutbound = new Array[Long](numProductBlocks) + + for (u <- 0 until numUserBlocks; p <- 0 until numProductBlocks) { + userOutbound(u) += userSendGrid.getOrElse((u, p), 0L) + prodInbound(p) += userSendGrid.getOrElse((u, p), 0L) + userInbound(u) += prodSendGrid.getOrElse((p, u), 0L) + prodOutbound(p) += prodSendGrid.getOrElse((p, u), 0L) + } + + val userCounts = userOut.mapValues(x => x.elementIds.length).collectAsMap() + val prodCounts = prodOut.mapValues(x => x.elementIds.length).collectAsMap() + + val userRatings = countRatings(userIn) + val prodRatings = countRatings(prodIn) + + val userStats = Array.tabulate(numUserBlocks)( + u => BlockStats("user", u, userCounts(u), userRatings(u), userInbound(u), userOutbound(u))) + val productStatus = Array.tabulate(numProductBlocks)( + p => BlockStats("product", p, prodCounts(p), prodRatings(p), prodInbound(p), prodOutbound(p))) + + (userStats ++ productStatus).toArray + } + + private def countRatings(inLinks: RDD[(Int, InLinkBlock)]): Map[Int, Long] = { + inLinks.mapValues { ilb => + var numRatings = 0L + ilb.ratingsForBlock.foreach { ar => + ar.foreach { p => numRatings += p._1.length } + } + numRatings + }.collectAsMap().toMap + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 899286d235a9d..a1a76fcbe9f9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -65,6 +65,48 @@ class MatrixFactorizationModel private[mllib] ( } } + /** + * Recommends products to a user. + * + * @param user the user to recommend products to + * @param num how many products to return. 
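A sketch tying the ALS changes above together: checkpointing only kicks in when a checkpoint directory is set on the SparkContext, and `analyzeBlocks` reports per-block cost before committing to a blocking configuration. The ratings file, its format, and the parameter values are illustrative:

    import org.apache.spark.mllib.recommendation.{ALS, Rating}

    val ratings = sc.textFile("data/ratings.csv").map { line =>
      val Array(user, product, rating) = line.split(',')
      Rating(user.toInt, product.toInt, rating.toDouble)
    }

    // With a checkpoint directory set, the product factors are checkpointed every
    // third iteration, keeping the lineage short on long runs.
    sc.setCheckpointDir("/tmp/spark-checkpoints")
    val model = ALS.train(ratings, 10, 20, 0.01)   // rank = 10, 20 iterations, lambda = 0.01

    // New DeveloperApi: block statistics for estimating cost and load balance.
    ALS.analyzeBlocks(ratings, 10, 10).foreach(println)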
The number returned may be less than this. + * @return [[Rating]] objects, each of which contains the given user ID, a product ID, and a + * "score" in the rating field. Each represents one recommended product, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the user. The score is an opaque value that indicates how strongly + * recommended the product is. + */ + def recommendProducts(user: Int, num: Int): Array[Rating] = + recommend(userFeatures.lookup(user).head, productFeatures, num) + .map(t => Rating(user, t._1, t._2)) + + /** + * Recommends users to a product. That is, this returns users who are most likely to be + * interested in a product. + * + * @param product the product to recommend users to + * @param num how many users to return. The number returned may be less than this. + * @return [[Rating]] objects, each of which contains a user ID, the given product ID, and a + * "score" in the rating field. Each represents one recommended user, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the product. The score is an opaque value that indicates how strongly + * recommended the user is. + */ + def recommendUsers(product: Int, num: Int): Array[Rating] = + recommend(productFeatures.lookup(product).head, userFeatures, num) + .map(t => Rating(t._1, product, t._2)) + + private def recommend( + recommendToFeatures: Array[Double], + recommendableFeatures: RDD[(Int, Array[Double])], + num: Int): Array[(Int, Double)] = { + val recommendToVector = new DoubleMatrix(recommendToFeatures) + val scored = recommendableFeatures.map { case (id,features) => + (id, recommendToVector.dot(new DoubleMatrix(features))) + } + scored.top(num)(Ordering.by(_._2)) + } + /** * :: DeveloperApi :: * Predict the rating of many users for many products. @@ -80,6 +122,4 @@ class MatrixFactorizationModel private[mllib] ( predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) } - // TODO: Figure out what other good bulk prediction methods would look like. - // Probably want a way to get the top users for a product or vice-versa. } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 8c078ec9f66e9..81b6598377ff5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -49,7 +49,7 @@ class LinearRegressionModel ( * its corresponding right hand side label y. * See also the documentation for the precise formulation. 
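Continuing the sketch above with the model returned by `ALS.train`; the IDs are placeholders and must appear in the training data, since the methods look up the corresponding feature vector:

    // Top 5 products for user 42, and the 5 users most interested in product 7,
    // each returned as Rating objects sorted by decreasing score.
    val topProducts = model.recommendProducts(42, 5)
    val topUsers = model.recommendUsers(7, 5)
    topProducts.foreach(r => println(s"product ${r.product} scored ${r.rating}"))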
*/ -class LinearRegressionWithSGD private ( +class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, private var miniBatchFraction: Double) @@ -68,7 +68,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LinearRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala new file mode 100644 index 0000000000000..b8b0b42611775 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream + +/** + * :: DeveloperApi :: + * StreamingLinearAlgorithm implements methods for continuously + * training a generalized linear model model on streaming data, + * and using it for prediction on (possibly different) streaming data. + * + * This class takes as type parameters a GeneralizedLinearModel, + * and a GeneralizedLinearAlgorithm, making it easy to extend to construct + * streaming versions of any analyses using GLMs. + * Initial weights must be set before calling trainOn or predictOn. + * Only weights will be updated, not an intercept. If the model needs + * an intercept, it should be manually appended to the input data. + * + * For example usage, see `StreamingLinearRegressionWithSGD`. + * + * NOTE(Freeman): In some use cases, the order in which trainOn and predictOn + * are called in an application will affect the results. When called on + * the same DStream, if trainOn is called before predictOn, when new data + * arrive the model will update and the prediction will be based on the new + * model. Whereas if predictOn is called first, the prediction will use the model + * from the previous update. + * + * NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this + * will generate predictions for each one all using the current model. + * It is also ok to call trainOn on different streams; this will update + * the model using each of the different sources, in sequence. + * + */ +@DeveloperApi +abstract class StreamingLinearAlgorithm[ + M <: GeneralizedLinearModel, + A <: GeneralizedLinearAlgorithm[M]] extends Logging { + + /** The model to be updated and used for prediction. 
*/ + protected var model: M + + /** The algorithm to use for updating. */ + protected val algorithm: A + + /** Return the latest model. */ + def latestModel(): M = { + model + } + + /** + * Update the model by training on batches of data from a DStream. + * This operation registers a DStream for training the model, + * and updates the model based on every subsequent + * batch of data from the stream. + * + * @param data DStream containing labeled data + */ + def trainOn(data: DStream[LabeledPoint]) { + if (Option(model.weights) == None) { + logError("Initial weights must be set before starting training") + throw new IllegalArgumentException + } + data.foreachRDD { (rdd, time) => + model = algorithm.run(rdd, model.weights) + logInfo("Model updated at time %s".format(time.toString)) + val display = model.weights.size match { + case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.weights.toArray.mkString("[", ",", "]") + } + logInfo("Current model: weights, %s".format (display)) + } + } + + /** + * Use the model to make predictions on batches of data from a DStream + * + * @param data DStream containing labeled data + * @return DStream containing predictions + */ + def predictOn(data: DStream[LabeledPoint]): DStream[Double] = { + if (Option(model.weights) == None) { + logError("Initial weights must be set before starting prediction") + throw new IllegalArgumentException + } + data.map(x => model.predict(x.features)) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala new file mode 100644 index 0000000000000..8851097050318 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * Train or predict a linear regression model on streaming data. Training uses + * Stochastic Gradient Descent to update the model based on each new batch of + * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation) + * + * Each batch of data is assumed to be an RDD of LabeledPoints. + * The number of data points per batch can vary, but the number + * of features must be constant. An initial weight + * vector must be provided. 
+ * + * Use a builder pattern to construct a streaming linear regression + * analysis in an application, like: + * + * val model = new StreamingLinearRegressionWithSGD() + * .setStepSize(0.5) + * .setNumIterations(10) + * .setInitialWeights(Vectors.dense(...)) + * .trainOn(DStream) + * + */ +@Experimental +class StreamingLinearRegressionWithSGD ( + private var stepSize: Double, + private var numIterations: Int, + private var miniBatchFraction: Double, + private var initialWeights: Vector) + extends StreamingLinearAlgorithm[ + LinearRegressionModel, LinearRegressionWithSGD] with Serializable { + + /** + * Construct a StreamingLinearRegression object with default parameters: + * {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0}. + * Initial weights must be set before using trainOn or predictOn + * (see `StreamingLinearAlgorithm`) + */ + def this() = this(0.1, 50, 1.0, null) + + val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + + var model = algorithm.createModel(initialWeights, 0.0) + + /** Set the step size for gradient descent. Default: 0.1. */ + def setStepSize(stepSize: Double): this.type = { + this.algorithm.optimizer.setStepSize(stepSize) + this + } + + /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + def setNumIterations(numIterations: Int): this.type = { + this.algorithm.optimizer.setNumIterations(numIterations) + this + } + + /** Set the fraction of each batch to use for updates. Default: 1.0. */ + def setMiniBatchFraction(miniBatchFraction: Double): this.type = { + this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) + this + } + + /** Set the initial weights. Default: [0.0, 0.0]. */ + def setInitialWeights(initialWeights: Vector): this.type = { + this.model = algorithm.createModel(initialWeights, 0.0) + this + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 68f3867ba6c11..f416a9fbb323d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -23,23 +23,26 @@ import org.apache.spark.mllib.stat.correlation.Correlations import org.apache.spark.rdd.RDD /** - * API for statistical functions in MLlib + * API for statistical functions in MLlib. */ @Experimental object Statistics { /** + * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. - * Returns NaN if either vector has 0 variance. + * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. */ + @Experimental def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** + * :: Experimental :: * Compute the correlation matrix for the input RDD of Vectors using the specified method. - * Methods currently supported: `pearson` (default), `spearman` + * Methods currently supported: `pearson` (default), `spearman`. * * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], @@ -51,28 +54,39 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. 
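A usage sketch for the streaming regression class above; the directories, batch interval, and line format are assumptions, and the initial weight vector must match the feature dimension:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val ssc = new StreamingContext(sc, Seconds(10))

    // Assumed line format: "label,f1 f2 f3"
    def parse(line: String): LabeledPoint = {
      val Array(label, features) = line.split(',')
      LabeledPoint(label.toDouble, Vectors.dense(features.split(' ').map(_.toDouble)))
    }

    val trainingStream = ssc.textFileStream("data/train").map(parse)
    val testStream = ssc.textFileStream("data/test").map(parse)

    val model = new StreamingLinearRegressionWithSGD()
      .setStepSize(0.5)
      .setNumIterations(10)
      .setInitialWeights(Vectors.dense(0.0, 0.0, 0.0))

    model.trainOn(trainingStream)          // updates the model on every training batch
    model.predictOn(testStream).print()    // predictions use the latest model

    ssc.start()
    ssc.awaitTermination()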
*/ + @Experimental def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** + * :: Experimental :: * Compute the Pearson correlation for the input RDDs. - * Columns with 0 covariance produce NaN entries in the correlation matrix. + * Returns NaN if either vector has 0 variance. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ + @Experimental def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** + * :: Experimental :: * Compute the correlation for the input RDDs using the specified method. - * Methods currently supported: pearson (default), spearman + * Methods currently supported: `pearson` (default), `spearman`. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` *@return A Double containing the correlation between the two input RDD[Double]s using the * specified method. */ + @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala index f23393d3da257..1fb8d7b3d4f32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala @@ -49,43 +49,48 @@ private[stat] trait Correlation { } /** - * Delegates computation to the specific correlation object based on the input method name - * - * Currently supported correlations: pearson, spearman. - * After new correlation algorithms are added, please update the documentation here and in - * Statistics.scala for the correlation APIs. - * - * Maintains the default correlation type, pearson + * Delegates computation to the specific correlation object based on the input method name. 
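A sketch of the correlation API documented above; per the notes, the two RDD[Double] inputs need identical partitioning, which holds here because both come from parallelizing equal-length sequences into the same number of partitions:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.Statistics

    val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0), 2)
    val y = sc.parallelize(Seq(2.0, 4.1, 6.0, 7.9), 2)

    val pearson = Statistics.corr(x, y)               // Pearson is the default
    val spearman = Statistics.corr(x, y, "spearman")

    val X = sc.parallelize(Seq(
      Vectors.dense(1.0, 10.0), Vectors.dense(2.0, 21.0), Vectors.dense(3.0, 33.0)))
    val corrMatrix = Statistics.corr(X, "pearson")    // column-wise correlation Matrix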
*/ private[stat] object Correlations { - // Note: after new types of correlations are implemented, please update this map - val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) - val defaultCorrName: String = "pearson" - val defaultCorr: Correlation = nameToObjectMap(defaultCorrName) - - def corr(x: RDD[Double], y: RDD[Double], method: String = defaultCorrName): Double = { + def corr(x: RDD[Double], + y: RDD[Double], + method: String = CorrelationNames.defaultCorrName): Double = { val correlation = getCorrelationFromName(method) correlation.computeCorrelation(x, y) } - def corrMatrix(X: RDD[Vector], method: String = defaultCorrName): Matrix = { + def corrMatrix(X: RDD[Vector], + method: String = CorrelationNames.defaultCorrName): Matrix = { val correlation = getCorrelationFromName(method) correlation.computeCorrelationMatrix(X) } - /** - * Match input correlation name with a known name via simple string matching - * - * private to stat for ease of unit testing - */ - private[stat] def getCorrelationFromName(method: String): Correlation = { + // Match input correlation name with a known name via simple string matching. + def getCorrelationFromName(method: String): Correlation = { try { - nameToObjectMap(method) + CorrelationNames.nameToObjectMap(method) } catch { case nse: NoSuchElementException => throw new IllegalArgumentException("Unrecognized method name. Supported correlations: " - + nameToObjectMap.keys.mkString(", ")) + + CorrelationNames.nameToObjectMap.keys.mkString(", ")) } } } + +/** + * Maintains supported and default correlation names. + * + * Currently supported correlations: `pearson`, `spearman`. + * Current default correlation: `pearson`. + * + * After new correlation algorithms are added, please update the documentation here and in + * Statistics.scala for the correlation APIs. + */ +private[mllib] object CorrelationNames { + + // Note: after new types of correlations are implemented, please update this map. 
+ val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) + val defaultCorrName: String = "pearson" + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 1f7de630e778c..9bd0c2cd05de4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -89,20 +89,18 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => // add an extra element to signify the end of the list so that flatMap can flush the last // batch of duplicates - val padded = iter ++ - Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L)) - var lastVal = 0.0 - var firstRank = 0.0 - val idBuffer = new ArrayBuffer[Long]() + val end = -1L + val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) + val firstEntry = padded.next() + var lastVal = firstEntry._1._1 + var firstRank = firstEntry._2.toDouble + val idBuffer = ArrayBuffer(firstEntry._1._2) padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != Long.MinValue) { + if (v == lastVal && id != end) { idBuffer += id Iterator.empty } else { - val entries = if (idBuffer.size == 0) { - // edge case for the first value matching the initial value of lastVal - Iterator.empty - } else if (idBuffer.size == 1) { + val entries = if (idBuffer.size == 1) { Iterator((idBuffer(0), firstRank)) } else { val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad32e3f4560fe..1d03e6e3b36cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * :: Experimental :: - * A class that implements a decision tree algorithm for classification and regression. It - * supports both continuous and categorical features. + * A class which implements a decision tree learning algorithm for classification and regression. + * It supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. @@ -42,25 +42,25 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Method to train a decision tree model over an RDD - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @return a DecisionTreeModel that can be used for prediction + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { // Cache input RDD for speedup during multiple passes. - input.cache() + val retaggedInput = input.retag(classOf[LabeledPoint]).cache() logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. 
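The SpearmanCorrelation rewrite above changes how runs of duplicate values are flushed when assigning ranks: every member of a tied group receives the average of the group's ranks. A minimal, Spark-free sketch of that averaging rule (the function name and types are illustrative, not from the patch):

/** Assign 0-based ranks to a sorted sequence, averaging ranks over tied values. */
def averagedRanks(sortedValues: Seq[Double]): Seq[Double] = {
  sortedValues.zipWithIndex               // (value, rank before tie handling)
    .groupBy(_._1)                        // group tied values
    .toSeq
    .flatMap { case (_, group) =>
      val avg = group.map(_._2).sum.toDouble / group.size
      group.map { case (_, rank) => (rank, avg) }
    }
    .sortBy(_._1)
    .map(_._2)
}

// averagedRanks(Seq(1.0, 2.0, 2.0, 5.0)) == Seq(0.0, 1.5, 1.5, 3.0)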
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = input.take(1)(0).features.size + val numFeatures = retaggedInput.take(1)(0).features.size // Calculate level for single group construction @@ -100,14 +100,14 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var level = 0 var break = false - while (level < maxDepth && !break) { + while (level <= maxDepth && !break) { logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, + val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { @@ -152,7 +152,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -173,7 +173,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { + if (level < maxDepth) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { @@ -197,17 +197,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo object DecisionTree extends Serializable with Logging { /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. The parameters for the algorithm are specified using the strategy parameter. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. 
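The depth-convention change above (maxDepth 0 means a single leaf) drives several of the edited formulas: the node array is sized 2^(maxDepth + 1) - 1 and the nodes of a level start at index 2^level - 1. A tiny arithmetic sketch of that bookkeeping:

/** Total number of nodes in a complete binary tree of the given depth (depth 0 = 1 leaf). */
def maxNumNodes(maxDepth: Int): Int = math.pow(2, maxDepth + 1).toInt - 1

/** Index of the first node at a level in the 0-based breadth-first layout. */
def firstNodeIndexAtLevel(level: Int): Int = math.pow(2, level).toInt - 1

// maxNumNodes(0) == 1, maxNumNodes(1) == 3, maxNumNodes(2) == 7
// firstNodeIndexAtLevel(2) == 3, so level-2 nodes occupy indices 3 to 6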
+ * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).train(input) @@ -219,12 +218,14 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @return a DecisionTreeModel that can be used for prediction + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -241,13 +242,15 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -266,11 +269,13 @@ object DecisionTree extends Serializable with Logging { * 1 to denote the two classes. The method also supports categorical features inputs where the * number of categories can specified using the categoricalFeaturesInfo option. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
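A hedged sketch of calling the documented train entry point through a Strategy, assuming labels are encoded as {0, ..., numClasses - 1} as the updated docs require; trainingData and the chosen parameter values are illustrative, and the remaining Strategy parameters (maxBins, quantile strategy, categorical feature info) are left at their defaults.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.rdd.RDD

def trainExample(trainingData: RDD[LabeledPoint]) = {
  // maxDepth = 2 allows up to 1 root + 2 internal + 4 leaf nodes under the new convention.
  val strategy = new Strategy(
    algo = Classification,
    impurity = Gini,
    maxDepth = 2,
    numClassesForClassification = 3)
  DecisionTree.train(trainingData, strategy)
}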
* @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles @@ -279,7 +284,7 @@ object DecisionTree extends Serializable with Logging { * an entry (n -> k) implies the feature n is categorical with k * categories 0, 1, 2, ... , k-1. It's important to note that * features are zero-indexed. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -301,11 +306,10 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -348,11 +352,10 @@ object DecisionTree extends Serializable with Logging { /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -373,7 +376,7 @@ object DecisionTree extends Serializable with Logging { groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* - * The high-level description for the best split optimizations are noted here. + * The high-level descriptions of the best split optimizations are noted here. * * *Level-wise training* * We perform bin calculations for all nodes at the given level to avoid making multiple @@ -396,18 +399,27 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // common calculations for multiple nested methods + // Common calculations for multiple nested methods: + + // numNodes: Number of nodes in this (level of tree, group), + // where nodes at deeper (larger) levels may be divided into groups. val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. 
val numFeatures = input.first().features.size logDebug("numFeatures = " + numFeatures) + + // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) + val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlassClassificationWithCategoricalFeatures = strategy.isMulticlassWithCategoricalFeatures logDebug("isMultiClassWithCategoricalFeatures = " + @@ -465,10 +477,13 @@ object DecisionTree extends Serializable with Logging { } /** - * Find bin for one feature. + * Find bin for one (labeledPoint, feature). */ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -483,7 +498,7 @@ object DecisionTree extends Serializable with Logging { val bin = binForFeatures(mid) val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)){ + if ((lowThreshold < feature) && (highThreshold >= feature)) { return mid } else if (lowThreshold >= feature) { @@ -507,41 +522,51 @@ object DecisionTree extends Serializable with Logging { } /** - * Sequential search helper method to find bin for categorical feature. + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). */ - def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureValue = labeledPoint.features(featureIndex) var binIndex = 0 - while (binIndex < numCategoricalBins) { + while (binIndex < featureCategories) { val bin = bins(featureIndex)(binIndex) val categories = bin.highSplit.categories - val features = labeledPoint.features - if (categories.contains(features(featureIndex))) { + if (categories.contains(featureValue)) { return binIndex } binIndex += 1 } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } -1 } if (isFeatureContinuous) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for continuous variable.") } binIndex } else { // Perform sequential search to find bin for categorical features. 
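findBin above locates a continuous value by binary search over bins whose boundaries satisfy lowThreshold < feature <= highThreshold. A standalone sketch of the same search over a plain array of bin upper thresholds (a simplified stand-in for the patch's Bin objects):

/**
 * Find the bin for a continuous value by binary search.
 * Bin i covers (lower(i), upper(i)], where lower(0) = -inf and the last upper bound is +inf.
 */
def findContinuousBin(value: Double, upperThresholds: Array[Double]): Int = {
  var left = 0
  var right = upperThresholds.length - 1
  while (left <= right) {
    val mid = left + (right - left) / 2
    val lower = if (mid == 0) Double.NegativeInfinity else upperThresholds(mid - 1)
    val upper = upperThresholds(mid)
    if (lower < value && value <= upper) {
      return mid
    } else if (value <= lower) {
      right = mid - 1
    } else {
      left = mid + 1
    }
  }
  -1 // not reached when the last threshold is +inf
}

// With upperThresholds = Array(1.0, 2.5, Double.PositiveInfinity),
// the values 0.3, 2.0 and 9.9 map to bins 0, 1 and 2.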
val binIndex = { - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { - sequentialBinSearchForOrderedCategoricalFeatureInClassification() + sequentialBinSearchForOrderedCategoricalFeature() } } - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for categorical variable.") } binIndex @@ -555,6 +580,14 @@ object DecisionTree extends Serializable with Logging { * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. + * + * For unordered features, the "bin index" returned is actually the feature value (category). + * + * @return Array of size 1 + numFeatures * numNodes, where + * arr(0) = label for labeledPoint, and + * arr(1 + numFeatures * nodeIndex + featureIndex) = + * bin index for this labeledPoint + * (or InvalidBinIndex if labeledPoint is not handled by this node) */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. @@ -598,58 +631,85 @@ object DecisionTree extends Serializable with Logging { // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int) = { - + /** + * Increment aggregate in location for (node, feature, bin, label). + * + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. + */ + def updateBinForOrderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + label: Double, + featureIndex: Int): Unit = { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + val aggIndex = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + numClasses * arr(arrIndex).toInt + + label.toInt + agg(aggIndex) += 1 } - def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], - label: Double, agg: Array[Double], rightChildShift: Int) = { + /** + * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), + * where [bins] ranges over all bins. + * Updates left or right side of aggregate depending on split. + * + * @param arr arr(0) = label. + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param agg Indexed by (left/right, node, feature, bin, label) + * where label is the least significant bit. + * The left/right specifier is a 0/1 index indicating left/right child info. + * @param rightChildShift Offset for right side of agg. 
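The bin aggregation above flattens a (node, feature, bin, label) tensor into a single Array[Double], with label as the least significant index, exactly as in the new aggIndex expression. A small sketch of that addressing scheme with the dimensions passed explicitly:

/** Flat offset for (node, feature, bin, label) with label as the least significant index. */
def aggIndex(
    nodeIndex: Int,
    featureIndex: Int,
    binIndex: Int,
    label: Int,
    numFeatures: Int,
    numBins: Int,
    numClasses: Int): Int = {
  numClasses * numBins * numFeatures * nodeIndex +
    numClasses * numBins * featureIndex +
    numClasses * binIndex +
    label
}

// With 2 classes, 4 bins and 3 features, one node spans 2 * 4 * 3 = 24 consecutive entries,
// and (node = 1, feature = 2, bin = 3, label = 1) lands at index 24 + 16 + 6 + 1 = 47.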
+ */ + def updateBinForUnorderedFeature( + nodeIndex: Int, + featureIndex: Int, + arr: Array[Double], + label: Double, + agg: Array[Double], + rightChildShift: Int): Unit = { // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + val arrIndex = 1 + numFeatures * nodeIndex + featureIndex + val featureValue = arr(arrIndex).toInt // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val aggShift = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + label.toInt // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { - val labelInt = label.toInt - if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + 1 + val aggIndex = aggShift + binIndex * numClasses + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(aggIndex) += 1 } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + 1 + agg(rightChildShift + aggIndex) += 1 } binIndex += 1 } } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. */ - def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -671,17 +731,21 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numClasses * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). 
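For unordered features (multiclass classification with a low-arity categorical feature), the code above works with 2^(featureCategories - 1) - 1 bins, one per non-trivial subset split of the categories. An illustrative, Spark-free sketch that enumerates those subsets with bit masks; the actual patch materialises Split and Bin objects instead:

/** Category subsets defining the 2^(k-1) - 1 possible binary splits of k categories. */
def categorySubsets(numCategories: Int): Seq[List[Double]] = {
  val numSplits = math.pow(2.0, numCategories - 1).toInt - 1
  (1 to numSplits).map { mask =>
    (0 until numCategories).filter(c => ((mask >> c) & 1) == 1).map(_.toDouble).toList
  }
}

// categorySubsets(3) == Seq(List(0.0), List(1.0), List(0.0, 1.0)),
// i.e. {0}, {1}, {0,1}; each subset together with its complement gives one split.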
+ * For ordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. + * For unordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). + * @param agg Array storing aggregate calculation. + * For ordered features, this is of size: + * numClasses * numBins * numFeatures * numNodes. + * For unordered features, this is of size: + * 2 * numClasses * numBins * numFeatures * numNodes. */ - def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -717,16 +781,17 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * Performs a sequential aggregation over a partition for regression. + * For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. * - * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression + * @param agg Array storing aggregate calculation, updated by this function. + * Size: 3 * numBins * numFeatures * numNodes + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -757,14 +822,30 @@ object DecisionTree extends Serializable with Logging { /** * Performs a sequential aggregation over a partition. + * For l nodes, k features, + * For classification: + * Either the left count or the right count of one of the bins is + * incremented based upon whether the feature is classified as 0 or 1. + * For regression: + * The count, sum, sum of squares of one of the bins is incremented. + * + * @param agg Array storing aggregate calculation, updated by this function. + * Size for classification: + * numClasses * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. + * Size for regression: + * 3 * numBins * numFeatures * numNodes. + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). 
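regressionBinSeqOp above accumulates three numbers per bin: the count, the sum, and the sum of squares of the labels, which is enough to recover the mean and variance later without another pass over the data. A Spark-free sketch of that running aggregate:

/** (count, sum, sum of squares) for a collection of labels. */
def regressionAggregate(labels: Seq[Double]): (Double, Double, Double) =
  labels.foldLeft((0.0, 0.0, 0.0)) { case ((count, sum, sumSquares), label) =>
    (count + 1, sum + label, sumSquares + label * label)
  }

// regressionAggregate(Seq(1.0, 2.0, 3.0)) == (3.0, 6.0, 14.0)
// mean = sum / count; variance = (sumSquares - sum * sum / count) / count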
+ * @return agg */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => if(isMulticlassClassificationWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg) + multiclassWithCategoricalBinSeqOp(arr, agg) } else { - orderedClassificationBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(arr, agg) } case Regression => regressionBinSeqOp(arr, agg) } @@ -815,20 +896,10 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - var classIndex = 0 - val leftCounts: Array[Double] = new Array[Double](numClasses) - val rightCounts: Array[Double] = new Array[Double](numClasses) - var leftTotalCount = 0.0 - var rightTotalCount = 0.0 - while (classIndex < numClasses) { - val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) - val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) - leftCounts(classIndex) = leftClassCount - leftTotalCount += leftClassCount - rightCounts(classIndex) = rightClassCount - rightTotalCount += rightClassCount - classIndex += 1 - } + val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) + val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) + val leftTotalCount = leftCounts.sum + val rightTotalCount = rightCounts.sum val impurity = { if (level > 0) { @@ -845,33 +916,17 @@ object DecisionTree extends Serializable with Logging { } } - if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) - } - if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) - } - - val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) - val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - - val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. 
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } // Sum of count for each label - val leftRightCounts: Array[Double] - = leftCounts.zip(rightCounts) - .map{case (leftCount, rightCount) => leftCount + rightCount} + val leftRightCounts: Array[Double] = + leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => + leftCount + rightCount + } def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { @@ -885,6 +940,22 @@ object DecisionTree extends Serializable with Logging { val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(leftCounts, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(rightCounts, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) @@ -937,10 +1008,18 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], - * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, - * (numBins - 1), numClasses) + * @param binData Aggregate array slice from getBinDataForNode. + * For classification: + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. + * For ordered features, this is of size numClasses * numBins * numFeatures. + * For regression: + * This is of size 2 * numFeatures * numBins. + * @return (leftNodeAgg, rightNodeAgg) pair of arrays. + * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). + * For regression, each array is of size (numFeatures, (numBins - 1), 3). + * */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { @@ -983,6 +1062,11 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value + */ def findAggForUnorderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1052,7 +1136,7 @@ object DecisionTree extends Serializable with Logging { val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures){ + if (isMulticlassClassificationWithCategoricalFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -1107,7 +1191,7 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. 
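The refactored gain calculation above weights each child's impurity by its share of the counts and subtracts both from the parent impurity, while the prediction is the label with the largest combined count. A compact sketch of that arithmetic using Gini impurity as a stand-in for strategy.impurity:

def gini(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
}

/** (information gain, predicted label) for one candidate split. */
def gainForSplit(leftCounts: Array[Double], rightCounts: Array[Double]): (Double, Int) = {
  val leftTotal = leftCounts.sum
  val rightTotal = rightCounts.sum
  val total = leftTotal + rightTotal
  val combined = leftCounts.zip(rightCounts).map { case (l, r) => l + r }
  val gain = gini(combined) -
    (leftTotal / total) * gini(leftCounts) -
    (rightTotal / total) * gini(rightCounts)
  val predict = combined.indexOf(combined.max)
  (gain, predict)
}

// A pure split: gainForSplit(Array(4.0, 0.0), Array(0.0, 4.0)) gives gain 0.5 and predict 0.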
- * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param binData Bin data slice for this node, given by getBinDataForNode. * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -1133,7 +1217,7 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex : Double = { + val maxSplitIndex: Double = { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { numBins - 1 @@ -1162,8 +1246,8 @@ object DecisionTree extends Serializable with Logging { (bestFeatureIndex, bestSplitIndex, bestGainStats) } + logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } @@ -1214,8 +1298,17 @@ object DecisionTree extends Serializable with Logging { bestSplits } - private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + /** + * Get the number of values to be stored per node in the bin aggregates. + * + * @param numBins Number of bins = 1 + number of possible splits. + */ + private def getElementsPerNode( + numFeatures: Int, + numBins: Int, + numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, + algo: Algo): Int = { algo match { case Classification => if (isMulticlassClassificationWithCategoricalFeatures) { @@ -1228,18 +1321,40 @@ object DecisionTree extends Serializable with Logging { } /** - * Returns split and bins for decision tree calculation. - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * There are 2^(maxFeatureValue - 1) - 1 splits. + * (b) For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one split per category. + + * Categorical case (a) features are called unordered features. + * Other cases are called ordered features. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree - * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache - * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + * parameters for construction the DecisionTree + * @return A tuple of (splits,bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numBins - 1). 
+ * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() // Find the number of features by looking at the first sample @@ -1271,7 +1386,8 @@ object DecisionTree extends Serializable with Logging { logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val sampledInput = + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length val stride: Double = numSamples.toDouble / numBins @@ -1286,7 +1402,7 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 - while (featureIndex < numFeatures){ + while (featureIndex < numFeatures) { // Check whether the feature is continuous. val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { @@ -1294,8 +1410,10 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) for (index <- 0 until numBins - 1) { - val sampleIndex = (index + 1) * stride.toInt - val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + val sampleIndex = index * stride.toInt + // Set threshold halfway in between 2 samples. + val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + val split = new Split(featureIndex, threshold, Continuous, List()) splits(featureIndex)(index) = split } } else { // Categorical feature @@ -1304,8 +1422,10 @@ object DecisionTree extends Serializable with Logging { = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + // classification that satisfy the space constraint. + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -1330,8 +1450,13 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { - + } else { // ordered feature + /* For a given categorical feature, use a subsample of the data + * to choose how to arrange possible splits. + * This examines each category and computes a centroid. + * These centroids are later used to sort the possible splits. + * centroidForCategories is a mapping: category (for the given feature) --> centroid + */ val centroidForCategories = { if (isMulticlassClassification) { // For categorical variables in multiclass classification, @@ -1341,7 +1466,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. 
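For continuous features, findSplitsBins now sorts a sample of the feature values and places each threshold halfway between two neighbouring sampled values at an equal-frequency stride. The sketch below is a simplified variant of that idea (it indexes the sample slightly differently than the patch does) meant only to illustrate the midpoint choice:

/** Candidate thresholds for one continuous feature from a sorted sample (numBins - 1 splits). */
def continuousThresholds(sortedSample: Array[Double], numBins: Int): Array[Double] = {
  require(sortedSample.length >= numBins)
  val stride = sortedSample.length.toDouble / numBins
  Array.tabulate(numBins - 1) { index =>
    val sampleIndex = ((index + 1) * stride).toInt
    // Set the threshold halfway between two neighbouring samples.
    (sortedSample(sampleIndex - 1) + sortedSample(sampleIndex)) / 2.0
  }
}

// continuousThresholds(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0), numBins = 4)
//   == Array(2.5, 4.5, 6.5)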
The bins are sorted and they @@ -1352,7 +1477,7 @@ object DecisionTree extends Serializable with Logging { } } - logDebug("centriod for categories = " + centroidForCategories.mkString(",")) + logDebug("centroid for categories = " + centroidForCategories.mkString(",")) // Check for missing categorical variables and putting them last in the sorted list. val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() @@ -1367,7 +1492,7 @@ object DecisionTree extends Serializable with Logging { // bins sorted by centroids val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) var categoriesForSplit = List[Double]() categoriesSortedByCentroid.iterator.zipWithIndex.foreach { @@ -1397,7 +1522,7 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) - for (index <- 1 until numBins - 1){ + for (index <- 1 until numBins - 1) { val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Continuous, Double.MinValue) bins(featureIndex)(index) = bin diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7c027ac2fda6b..4ee4bcd0bcbc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.configuration +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.configuration.Algo._ @@ -27,7 +29,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * Stores all the configuration options for tree construction * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value is 2 * leads to binary classification * @param maxBins maximum number of bins used for splitting features @@ -52,9 +55,39 @@ class Strategy ( val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), val maxMemoryInMB: Int = 128) extends Serializable { - require(numClassesForClassification >= 2) - val isMulticlassClassification = numClassesForClassification > 2 + if (algo == Classification) { + require(numClassesForClassification >= 2) + } + val isMulticlassClassification = + algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + /** + * Java-friendly constructor. + * + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
+ * @param numClassesForClassification number of classes for classification. Default value is 2 + * leads to binary classification + * @param maxBins maximum number of bins used for splitting features + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, an entry + * (n -> k) implies the feature n is categorical with k categories + * 0, 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int, + maxBins: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { + this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a0e2d91762782..96d2471e1f88c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -34,10 +34,13 @@ object Entropy extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 0.0 var classIndex = 0 @@ -58,8 +61,16 @@ object Entropy extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 48144b5e6d1e4..d586f449048bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -33,10 +33,13 @@ object Gini extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 1.0 var classIndex = 0 @@ -54,8 +57,16 @@ object Gini extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") + + /** + * Get this impurity instance. 
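The impurity changes above add a guard so that an empty node yields impurity 0 rather than a division by zero. A minimal sketch of the guarded entropy computation; log base 2 is an assumption here (the base only rescales the value and does not change which split wins):

def log2(x: Double): Double = math.log(x) / math.log(2)

/** Entropy of a label distribution; returns 0 for an empty node instead of NaN. */
def entropy(counts: Array[Double], totalCount: Double): Double = {
  if (totalCount == 0) {
    return 0.0
  }
  counts.foldLeft(0.0) { (impurity, classCount) =>
    if (classCount == 0) impurity
    else {
      val freq = classCount / totalCount
      impurity - freq * log2(freq)
    }
  }
}

// entropy(Array(5.0, 5.0), 10.0) == 1.0; entropy(Array(0.0, 0.0), 0.0) == 0.0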
+ * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 7b2a9320cc21d..92b0c7b4a6fbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -31,7 +31,7 @@ trait Impurity extends Serializable { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -42,7 +42,7 @@ trait Impurity extends Serializable { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 97149a99ead59..f7d99a40eb380 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -31,7 +31,7 @@ object Variance extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = @@ -43,10 +43,21 @@ object Variance extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + if (count == 0) { + return 0 + } val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index bf692ca8c4bd7..3d3406b5d5f22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,8 @@ import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: - * Model to store the decision tree parameters + * Decision tree model for classification or regression. + * This model stores the decision tree structure and parameters. * @param topNode root node * @param algo algorithm type -- classification or regression */ @@ -50,4 +51,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } + + /** + * Get number of nodes in tree, including leaf nodes. 
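Variance.calculate above now returns 0 for count == 0 and otherwise derives the population variance directly from the three running sums. A one-line sketch plus a worked check:

/** Population variance from running (count, sum, sum of squares); 0 for an empty node. */
def variance(count: Double, sum: Double, sumSquares: Double): Double =
  if (count == 0) 0.0 else (sumSquares - (sum * sum) / count) / count

// For labels 1, 2, 3: count = 3, sum = 6, sumSquares = 14,
// so variance == (14 - 36 / 3) / 3 == 2.0 / 3.0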
+ */ + def numNodes: Int = { + 1 + topNode.numDescendants + } + + /** + * Get depth of tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + def depth: Int = { + topNode.subtreeDepth + } + + /** + * Print full model. + */ + override def toString: String = algo match { + case Classification => + s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2) + case Regression => + s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2) + case _ => throw new IllegalArgumentException( + s"DecisionTreeModel given unknown algo parameter: $algo.") + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 682f213f411a7..944f11c2c2e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -91,4 +91,60 @@ class Node ( } } } + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int = { + if (isLeaf) { + 0 + } else { + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. + */ + private[tree] def subtreeDepth: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) + } + } + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String = { + + def splitToString(split: Split, left: Boolean): String = { + split.featureType match { + case Continuous => if (left) { + s"(feature ${split.feature} <= ${split.threshold})" + } else { + s"(feature ${split.feature} > ${split.threshold})" + } + case Categorical => if (left) { + s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + } else { + s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + } + } + } + val prefix: String = " " * indentFactor + if (isLeaf) { + prefix + s"Predict: $predict\n" + } else { + prefix + s"If ${splitToString(split.get, left=true)}\n" + + leftNode.get.subtreeToString(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + + rightNode.get.subtreeToString(indentFactor + 1) + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala deleted file mode 100644 index e25bf18b780bf..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.util - -/** Trait for label parsers. */ -private trait LabelParser extends Serializable { - /** Parses a string label into a double label. */ - def parse(labelString: String): Double -} - -/** Factory methods for label parsers. */ -private object LabelParser { - def getInstance(multiclass: Boolean): LabelParser = { - if (multiclass) MulticlassLabelParser else BinaryLabelParser - } -} - -/** - * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5, - * or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling. - */ -private object BinaryLabelParser extends LabelParser { - /** Gets the default instance of BinaryLabelParser. */ - def getInstance(): LabelParser = this - - /** - * Parses the input label into positive (1.0) if the value is greater than 0.5, - * or negative (0.0) otherwise. - */ - override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 -} - -/** - * Label parser for multiclass labels, which converts the input label to double. - */ -private object MulticlassLabelParser extends LabelParser { - /** Gets the default instance of MulticlassLabelParser. */ - def getInstance(): LabelParser = this - - override def parse(labelString: String): Double = labelString.toDouble -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 30de24ad89f98..f4cce86a65ba7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -30,6 +30,8 @@ import org.apache.spark.util.random.BernoulliSampler import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream.DStream /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -55,7 +57,6 @@ object MLUtils { * * @param sc Spark context * @param path file or directory path in any Hadoop-supported file system URI - * @param labelParser parser for labels * @param numFeatures number of features, which will be determined from the input data if a * nonpositive value is given. This is useful when the dataset is already split * into multiple files and you want to load them separately, because some @@ -64,10 +65,9 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] */ - private def loadLibSVMFile( + def loadLibSVMFile( sc: SparkContext, path: String, - labelParser: LabelParser, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = { val parsed = sc.textFile(path, minPartitions) @@ -75,7 +75,7 @@ object MLUtils { .filter(line => !(line.isEmpty || line.startsWith("#"))) .map { line => val items = line.split(' ') - val label = labelParser.parse(items.head) + val label = items.head.toDouble val (indices, values) = items.tail.map { item => val indexAndValue = item.split(':') val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. @@ -102,64 +102,46 @@ object MLUtils { // Convenient methods for `loadLibSVMFile`. - /** - * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint]. 
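loadLibSVMFile above parses each non-comment line as `label index1:value1 index2:value2 ...`, converting the 1-based indices to 0-based. A hedged sketch of parsing a single such line; it mirrors the format but not the exact MLUtils internals:

/** Parse one LIBSVM-format line into (label, 0-based indices, values). */
def parseLibSVMLine(line: String): (Double, Array[Int], Array[Double]) = {
  val items = line.trim.split(' ')
  val label = items.head.toDouble
  val pairs = items.tail.map { item =>
    val Array(index, value) = item.split(':')
    (index.toInt - 1, value.toDouble) // 1-based index in the file, 0-based in the vector
  }
  (label, pairs.map(_._1), pairs.map(_._2))
}

// parseLibSVMLine("1 3:4.5 10:2.0") yields label 1.0, indices Array(2, 9), values Array(4.5, 2.0)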
- * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR. - * Each line represents a labeled sparse feature vector using the following format: - * {{{label index1:value1 index2:value2 ...}}} - * where the indices are one-based and in ascending order. - * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]], - * where the feature indices are converted to zero-based. - * - * @param sc Spark context - * @param path file or directory path in any Hadoop-supported file system URI - * @param multiclass whether the input labels contain more than two classes. If false, any label - * with value greater than 0.5 will be mapped to 1.0, or 0.0 otherwise. So it - * works for both +1/-1 and 1/0 cases. If true, the double value parsed directly - * from the label string will be used as the label value. - * @param numFeatures number of features, which will be determined from the input data if a - * nonpositive value is given. This is useful when the dataset is already split - * into multiple files and you want to load them separately, because some - * features may not present in certain files, which leads to inconsistent - * feature dimensions. - * @param minPartitions min number of partitions - * @return labeled data stored as an RDD[LabeledPoint] - */ - def loadLibSVMFile( + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") + def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, LabelParser.getInstance(multiclass), numFeatures, minPartitions) + loadLibSVMFile(sc, path, numFeatures, minPartitions) /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. */ + def loadLibSVMFile( + sc: SparkContext, + path: String, + numFeatures: Int): RDD[LabeledPoint] = + loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean, numFeatures: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass, numFeatures, sc.defaultMinPartitions) + loadLibSVMFile(sc, path, numFeatures) - /** - * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the number of features - * determined automatically and the default number of partitions. - */ + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass, -1, sc.defaultMinPartitions) + loadLibSVMFile(sc, path) /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. */ def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass = false, -1, sc.defaultMinPartitions) + loadLibSVMFile(sc, path, -1) /** * Save labeled data in LIBSVM format. @@ -212,6 +194,19 @@ object MLUtils { def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) + /** + * Loads streaming labeled points from a stream of text files + * where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`. 
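// Reviewer sketch (illustrative, not part of the patch): what the deprecations
// mean for callers. Both calls below return identical results; the multiclass
// flag is ignored and only produces a deprecation warning at compile time.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils

object LoadLibSVMMigrationSketch {
  def migrate(sc: SparkContext, path: String) = {
    val before = MLUtils.loadLibSVMFile(sc, path, true)  // deprecated since 1.1.0
    val after  = MLUtils.loadLibSVMFile(sc, path)        // preferred form
    (before, after)
  }
}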
+ * See `StreamingContext.textFileStream` for more details on how to + * generate a stream from files + * + * @param ssc Streaming context + * @param dir Directory path in any Hadoop-supported file system URI + * @return Labeled points stored as a DStream[LabeledPoint] + */ + def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] = + ssc.textFileStream(dir).map(LabeledPointParser.parse) + /** * Load labeled data from a file. The data format used here is * , ... diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java new file mode 100644 index 0000000000000..e8d99f4ae43ae --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; + +public class JavaTfIdfSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaTfIdfSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void tfIdf() { + // The tests are to check Java compatibility. 
+ HashingTF tf = new HashingTF(); + JavaRDD> documents = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("this is a sentence".split(" ")), + Lists.newArrayList("this is another sentence".split(" ")), + Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD termFreqs = tf.transform(documents); + termFreqs.collect(); + IDF idf = new IDF(); + JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); + List localTfIdfs = tfIdfs.collect(); + int indexOfThis = tf.indexOf("this"); + for (Vector v: localTfIdfs) { + Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index bf2365f82044c..f6ca9643227f8 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -20,6 +20,11 @@ import java.io.Serializable; import java.util.List; +import scala.Tuple2; +import scala.Tuple3; + +import org.jblas.DoubleMatrix; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -28,8 +33,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.jblas.DoubleMatrix; - public class JavaALSSuite implements Serializable { private transient JavaSparkContext sc; @@ -44,21 +47,28 @@ public void tearDown() { sc = null; } - static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { + static void validatePrediction( + MatrixFactorizationModel model, + int users, + int products, + int features, + DoubleMatrix trueRatings, + double matchThreshold, + boolean implicitPrefs, + DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = model.userFeatures().toJavaRDD().collect(); + List> userFeatures = model.userFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 userFeature : userFeatures) { + for (Tuple2 userFeature : userFeatures) { predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); } } DoubleMatrix predictedP = new DoubleMatrix(products, features); - List> productFeatures = + List> productFeatures = model.productFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 productFeature : productFeatures) { + for (Tuple2 productFeature : productFeatures) { predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); } } @@ -75,7 +85,8 @@ static void validatePrediction(MatrixFactorizationModel model, int users, int pr } } } else { - // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests) + // For implicit prefs we use the confidence-weighted RMSE to test + // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; for (int u = 0; u < users; ++u) { @@ -100,7 +111,7 @@ public void runALSUsingStaticMethods() { int iterations = 15; int users = 50; int products = 100; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); @@ -114,14 +125,14 @@ public void runALSUsingConstructor() { int 
iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) - .setIterations(iterations) - .run(data.rdd()); + .setIterations(iterations) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); } @@ -131,7 +142,7 @@ public void runImplicitALSUsingStaticMethods() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -145,7 +156,7 @@ public void runImplicitALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -163,12 +174,42 @@ public void runImplicitALSWithNegativeWeight() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, true); JavaRDD data = sc.parallelize(testData._1()); - MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } + @Test + public void runRecommend() { + int features = 5; + int iterations = 10; + int users = 200; + int products = 50; + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true, false); + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); + validateRecommendations(model.recommendProducts(1, 10), 10); + validateRecommendations(model.recommendUsers(1, 20), 20); + } + + private static void validateRecommendations(Rating[] recommendations, int howMany) { + Assert.assertEquals(howMany, recommendations.length); + for (int i = 1; i < recommendations.length; i++) { + Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); + } + Assert.assertTrue(recommendations[0].rating() > 0.7); + } + } diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java new file mode 100644 index 0000000000000..2c281a1ee7157 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.configuration.Algo; +import org.apache.spark.mllib.tree.configuration.Strategy; +import org.apache.spark.mllib.tree.impurity.Gini; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; + + +public class JavaDecisionTreeSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + int validatePrediction(List validationData, DecisionTreeModel model) { + int numCorrect = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numCorrect++; + } + } + return numCorrect; + } + + @Test + public void runDTUsingConstructor() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTree learner = new DecisionTree(strategy); + DecisionTreeModel model = learner.train(rdd.rdd()); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + + @Test + public void runDTUsingStaticMethods() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index d94cfa2fcec81..bd413a80f5107 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends FunSuite { @@ -59,10 +59,25 @@ class PythonMLLibAPISuite extends FunSuite { } test("double serialization") { - for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) { + for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { val bytes = py.serializeDouble(x) val deser = py.deserializeDouble(bytes) - assert(x === deser) + // We use `equals` here for comparison because we cannot use `==` for NaN + assert(x.equals(deser)) } } + + test("matrix to 2D array") { + val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) + val matrix = Matrices.dense(2, 3, values) + val arr = py.to2dArray(matrix) + val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) + assert(arr === expected) + + // Test conversion for empty matrix + val empty = Array[Double]() + val emptyMatrix = Matrices.dense(0, 0, empty) + val empty2D = py.to2dArray(emptyMatrix) + assert(empty2D === Array[Array[Double]]()) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 3f6ff859374c7..da7c633bbd2af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ object LogisticRegressionSuite { @@ -81,9 +82,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD) // Test the weights - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -113,9 +113,9 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD, initialWeights) - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.01) + assert(model.intercept ~== 1.97 relTol 0.01) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 34bc4537a7b3a..afa1f79b95a12 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,8 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class KMeansSuite extends FunSuite with LocalSparkContext { @@ -41,26 +42,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // centered at the mean of the points var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train( data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("no distinct points") { @@ -104,26 +105,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { var model = KMeans.train(data, k = 1, maxIterations = 1) assert(model.clusterCenters.size === 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("single cluster with sparse data") { @@ -149,31 +150,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext { val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0))) var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, 
maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) data.unpersist() } test("k-means|| initialization") { + + case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { + @Override def compare(that: VectorWithCompare): Int = { + if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + } + } + val points = Seq( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), @@ -188,15 +197,19 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // unselected point as long as it hasn't yet selected all of them var model = KMeans.train(rdd, k = 5, maxIterations = 1) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Iterations of Lloyd's should not change the answer either model = KMeans.train(rdd, k = 5, maxIterations = 10) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Neither should more runs model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) } test("two clusters") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 1c9844f289fe0..994e0feb8629e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 - assert(AreaUnderCurve.of(curve) === auc) + assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) == auc) + assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5) } test("auc of an empty curve") { val curve = 
Seq.empty[(Double, Double)] - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } test("auc of a curve with a single point") { val curve = Seq((1.0, 1.0)) - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 94db1dc183230..a733f88b60b80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,25 +20,14 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals +import org.apache.spark.mllib.util.TestingUtils._ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { - // TODO: move utility functions to TestingUtils. + def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = { - actual.zip(expected).forall { case (x1, x2) => - x1.almostEquals(x2) - } - } - - def elementsAlmostEqual( - actual: Seq[(Double, Double)], - expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = { - actual.zip(expected).forall { case ((x1, y1), (x2, y2)) => - x1.almostEquals(x2) && y1.almostEquals(y2) - } - } + def cond2(x: ((Double, Double), (Double, Double))): Boolean = + (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( @@ -57,16 +46,17 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) val pr = recall.zip(precision) val prCurve = Seq((0.0, 1.0)) ++ pr - val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold)) - assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve)) - assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve))) - assert(elementsAlmostEqual(metrics.pr().collect(), prCurve)) - assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))) - assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision))) - assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall))) + + assert(metrics.thresholds().collect().zip(threshold).forall(cond1)) + assert(metrics.roc().collect().zip(rocCurve).forall(cond2)) + assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5) + assert(metrics.pr().collect().zip(prCurve).forall(cond2)) + 
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5) + assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2)) + assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2)) + assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2)) + assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala new file mode 100644 index 0000000000000..a599e0d938569 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.LocalSparkContext + +class HashingTFSuite extends FunSuite with LocalSparkContext { + + test("hashing tf on a single doc") { + val hashingTF = new HashingTF(1000) + val doc = "a a b b c d".split(" ") + val n = hashingTF.numFeatures + val termFreqs = Seq( + (hashingTF.indexOf("a"), 2.0), + (hashingTF.indexOf("b"), 2.0), + (hashingTF.indexOf("c"), 1.0), + (hashingTF.indexOf("d"), 1.0)) + assert(termFreqs.map(_._1).forall(i => i >= 0 && i < n), + "index must be in range [0, #features)") + assert(termFreqs.map(_._1).toSet.size === 4, "expecting perfect hashing") + val expected = Vectors.sparse(n, termFreqs) + assert(hashingTF.transform(doc) === expected) + } + + test("hashing tf on an RDD") { + val hashingTF = new HashingTF + val localDocs: Seq[Seq[String]] = Seq( + "a a b b b c d".split(" "), + "a b c d a b c".split(" "), + "c b a c b a a".split(" ")) + val docs = sc.parallelize(localDocs, 2) + assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala new file mode 100644 index 0000000000000..78a2804ff204b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
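// Reviewer sketch (illustrative, not part of the patch; names are made up): the
// Scala counterpart of the pipeline exercised in JavaTfIdfSuite and the suites
// above. Hash each term sequence into a term-frequency vector, fit IDF over the
// corpus, then rescale.
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

object TfIdfSketch {
  def tfIdf(docs: RDD[Seq[String]]): RDD[Vector] = {
    val hashingTF = new HashingTF(1000)  // small feature space for the sketch
    val tf = hashingTF.transform(docs)   // one sparse term-frequency vector per document
    tf.cache()                           // fit and transform below both traverse tf
    val idf = new IDF()
    idf.fit(tf)                          // collects document frequencies
    idf.transform(tf)                    // returns tf-idf weighted vectors
  }
}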
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class IDFSuite extends FunSuite with LocalSparkContext { + + test("idf") { + val n = 4 + val localTermFrequencies = Seq( + Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(n, Array(1), Array(1.0)) + ) + val m = localTermFrequencies.size + val termFrequencies = sc.parallelize(localTermFrequencies, 2) + val idf = new IDF + intercept[IllegalStateException] { + idf.idf() + } + intercept[IllegalStateException] { + idf.transform(termFrequencies) + } + idf.fit(termFrequencies) + val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((m.toDouble + 1.0) / (x + 1.0)) + }) + assert(idf.idf() ~== expected absTol 1e-12) + val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(tfidf.size === 3) + val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] + assert(tfidf0.indices === Array(1, 3)) + assert(Vectors.dense(tfidf0.values) ~== + Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12) + val tfidf1 = tfidf(1L).asInstanceOf[DenseVector] + assert(Vectors.dense(tfidf1.values) ~== + Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12) + val tfidf2 = tfidf(2L).asInstanceOf[SparseVector] + assert(tfidf2.indices === Array(1)) + assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala new file mode 100644 index 0000000000000..fb76dccfdf79e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
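// Reviewer note (illustrative, not part of the patch): where the expected vector in
// the "idf" test comes from. With m = 3 documents and per-feature document
// frequencies (0, 3, 1, 2), the test computes idf(t) = log((m + 1) / (df(t) + 1)).
object IdfExpectedValuesSketch {
  def main(args: Array[String]): Unit = {
    val m = 3
    val docFreq = Array(0, 3, 1, 2)
    val idf = docFreq.map(df => math.log((m + 1.0) / (df + 1.0)))
    // idf = [log 4 ≈ 1.386, log 1 = 0.0, log 2 ≈ 0.693, log(4/3) ≈ 0.288],
    // i.e. a term that never appears gets the largest weight and a term in every
    // document gets zero, matching the assertions above.
    idf.foreach(println)
  }
}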
+ */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class NormalizerSuite extends FunSuite with LocalSparkContext { + + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + + lazy val dataRDD = sc.parallelize(data, 3) + + test("Normalization using L1 distance") { + val l1Normalizer = new Normalizer(1) + + val data1 = data.map(l1Normalizer.transform) + val data1RDD = l1Normalizer.transform(dataRDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + + assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) + assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data1(2) ~== Vectors.dense(0.12765957, -0.23404255, -0.63829787) absTol 1E-5) + assert(data1(3) ~== Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))) absTol 1E-5) + assert(data1(4) ~== Vectors.dense(0.625, 0.07894737, 0.29605263) absTol 1E-5) + assert(data1(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L2 distance") { + val l2Normalizer = new Normalizer() + + val data2 = data.map(l2Normalizer.transform) + val data2RDD = l2Normalizer.transform(dataRDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + + assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) + assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data2(2) ~== Vectors.dense(0.184549876, -0.3383414, -0.922749378) absTol 1E-5) + assert(data2(3) ~== Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.897906166, 0.113419726, 0.42532397) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L^Inf distance.") { + val lInfNormalizer = new Normalizer(Double.PositiveInfinity) + + val dataInf = data.map(lInfNormalizer.transform) + val dataInfRDD = lInfNormalizer.transform(dataRDD) + + assert((data, dataInf, dataInfRDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, 
v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((dataInf, dataInfRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(dataInf(0).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(2).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(3).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(4).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + + assert(dataInf(0) ~== Vectors.sparse(3, Seq((0, -0.86956522), (1, 1.0))) absTol 1E-5) + assert(dataInf(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(dataInf(2) ~== Vectors.dense(0.2, -0.36666667, -1.0) absTol 1E-5) + assert(dataInf(3) ~== Vectors.sparse(3, Seq((1, 0.284375), (2, 1.0))) absTol 1E-5) + assert(dataInf(4) ~== Vectors.dense(1.0, 0.12631579, 0.473684211) absTol 1E-5) + assert(dataInf(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala new file mode 100644 index 0000000000000..5a9be923a8625 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
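// Reviewer note (illustrative, not part of the patch; object name is made up):
// what the expected vectors in NormalizerSuite encode. Normalizer(p) scales each
// vector by 1 / ||v||_p, and the tests show all-zero and empty vectors coming
// back unchanged.
import org.apache.spark.mllib.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

object NormalizerSketch {
  def main(args: Array[String]): Unit = {
    val v = Vectors.dense(0.6, -1.1, -3.0)  // ||v||_1 = 4.7, ||v||_inf = 3.0
    println(new Normalizer(1.0).transform(v))                      // ≈ (0.1277, -0.2340, -0.6383)
    println(new Normalizer().transform(v))                         // default p = 2, unit Euclidean norm
    println(new Normalizer(Double.PositiveInfinity).transform(v))  // ≈ (0.2, -0.3667, -1.0)
  }
}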
+ */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.rdd.RDD + +class StandardScalerSuite extends FunSuite with LocalSparkContext { + + private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { + data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + } + + test("Standardization with dense input") { + val data = Array( + Vectors.dense(-2.0, 2.3, 0), + Vectors.dense(0.0, -1.0, -3.0), + Vectors.dense(0.0, -5.1, 0.0), + Vectors.dense(3.8, 0.0, 1.9), + Vectors.dense(1.7, -0.6, 0.0), + Vectors.dense(0.0, 1.9, 0.0) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + withClue("Using a standardizer before fitting the model should throw exception.") { + intercept[IllegalStateException] { + data.map(standardizer1.transform) + } + } + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + val data1RDD = standardizer1.transform(dataRDD) + val data2RDD = standardizer2.transform(dataRDD) + val data3RDD = standardizer3.transform(dataRDD) + + val summary = computeSummary(dataRDD) + val summary1 = computeSummary(data1RDD) + val summary2 = computeSummary(data2RDD) + val summary3 = computeSummary(data3RDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data3, data3RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary3.variance ~== summary.variance absTol 1E-5) + + 
assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5) + assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5) + assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5) + assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5) + assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5) + } + + + test("Standardization with sparse input") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), + Vectors.sparse(3, Seq((1, -5.1))), + Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), + Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), + Vectors.sparse(3, Seq((1, 1.9))) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data2 = data.map(standardizer2.transform) + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer1.transform) + } + } + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer3.transform) + } + } + + val data2RDD = standardizer2.transform(dataRDD) + + val summary2 = computeSummary(data2RDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) + } + + test("Standardization with constant input") { + // When the input data is all constant, the variance is zero. The standardization against + // zero variance is not well-defined, but we decide to just set it into zero here. 
+ val data = Array( + Vectors.dense(2.0), + Vectors.dense(2.0), + Vectors.dense(2.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler(withMean = true, withStd = false) + val standardizer3 = new StandardScaler(withMean = false, withStd = true) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + assert(data1.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data2.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data3.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala new file mode 100644 index 0000000000000..e34335d89eb75 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
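// Reviewer sketch (illustrative, not part of the patch; names are made up): the
// StandardScaler flow these tests pin down. fit() must run before transform()
// (IllegalStateException otherwise), and withMean = true is rejected for sparse
// input, since subtracting the mean would densify the vectors.
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

object StandardScalerSketch {
  def standardize(data: RDD[Vector]): RDD[Vector] = {
    val scaler = new StandardScaler(withMean = false, withStd = true)  // sparse-safe setting
    scaler.fit(data)        // collects column means and variances
    scaler.transform(data)  // unit-variance columns; the mean is left as-is here
  }
}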
+ */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext + +class Word2VecSuite extends FunSuite with LocalSparkContext { + + // TODO: add more tests + + test("Word2Vec") { + val sentence = "a b " * 100 + "a c " * 10 + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + val syms = model.findSynonyms("a", 2) + assert(syms.length == 2) + assert(syms(0)._1 == "b") + assert(syms(1)._1 == "c") + } + + test("Word2VecModel") { + val num = 2 + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + val syms = model.findSynonyms("china", num) + assert(syms.length == num) + assert(syms(0)._1 == "taiwan") + assert(syms(1)._1 == "japan") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index dfb2eb7f0d14e..bf040110e228b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -126,19 +127,14 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept) - def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - - assert(compareDouble( - loss1(0), - loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + - math.pow(initialWeightsWithIntercept(1), 2)) / 2), + assert( + loss1(0) ~= (loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + + math.pow(initialWeightsWithIntercept(1), 2)) / 2) absTol 1E-5, """For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""") assert( - compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) && - compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)), + (newWeights1(0) ~= (newWeights0(0) - initialWeightsWithIntercept(0)) absTol 1E-5) && + (newWeights1(1) ~= (newWeights0(1) - initialWeightsWithIntercept(1)) absTol 1E-5), "The different between newWeights with/without regularization " + "should be initialWeightsWithIntercept.") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index ff414742e8393..5f4c24115ac80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import 
org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { @@ -49,10 +50,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { lazy val dataRDD = sc.parallelize(data, 2).cache() - def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - test("LBFGS loss should be decreasing and match the result of Gradient Descent.") { val regParam = 0 @@ -126,15 +123,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { miniBatchFrac, initialWeightsWithIntercept) - assert(compareDouble(lossGD(0), lossLBFGS(0)), + assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5, "The first losses of LBFGS and GD should be the same.") // The 2% difference here is based on observation, but is not theoretically guaranteed. - assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02), + assert(lossGD.last ~= lossLBFGS.last relTol 0.02, "The last losses of LBFGS and GD should be within 2% difference.") - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } @@ -226,8 +223,8 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { initialWeightsWithIntercept) // for class LBFGS and the optimize method, we only look at the weights - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index bbf385229081a..b781a6aed9a8c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -21,7 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas} +import org.jblas.{DoubleMatrix, SimpleBlas} + +import org.apache.spark.mllib.util.TestingUtils._ class NNLSSuite extends FunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. 
*/ @@ -73,7 +75,7 @@ class NNLSSuite extends FunSuite { val ws = NNLS.createWorkspace(n) val x = NNLS.solve(ata, atb, ws) for (i <- 0 until n) { - assert(Math.abs(x(i) - goodx(i)) < 1e-3) + assert(x(i) ~== goodx(i) absTol 1E-3) assert(x(i) >= 0) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala similarity index 95% rename from mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index 974dec4c0b5ee..3df7c128af5ab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -22,9 +22,9 @@ import org.scalatest.FunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class DistributionGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends FunSuite { - def apiChecks(gen: DistributionGenerator) { + def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers gen.setSeed(42L) @@ -53,7 +53,7 @@ class DistributionGeneratorSuite extends FunSuite { assert(array5.equals(array6)) } - def distributionChecks(gen: DistributionGenerator, + def distributionChecks(gen: RandomDataGenerator[Double], mean: Double = 0.0, stddev: Double = 1.0, epsilon: Double = 0.01) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala index 6aa4f803df0f7..96e0bc63b0fa4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala @@ -78,7 +78,9 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri assert(rdd.partitions.size === numPartitions) // check that partition sizes are balanced - val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble) + val partSizes = rdd.partitions.map(p => + p.asInstanceOf[RandomRDDPartition[Double]].size.toDouble) + val partStats = new StatCounter(partSizes) assert(partStats.max - partStats.min <= 1) } @@ -89,7 +91,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L) assert(rdd.partitions.size === numPartitions) val count = rdd.partitions.foldLeft(0L) { (count, part) => - count + part.asInstanceOf[RandomRDDPartition].size + count + part.asInstanceOf[RandomRDDPartition[Double]].size } assert(count === size) @@ -145,7 +147,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri } } -private[random] class MockDistro extends DistributionGenerator { +private[random] class MockDistro extends RandomDataGenerator[Double] { var seed = 0L diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 3f3b10dfff35e..27a19f793242b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,4 +46,22 @@ class 
RDDFunctionsSuite extends FunSuite with LocalSparkContext { val expected = data.flatMap(x => x).sliding(3).toList assert(sliding.collect().toList === expected) } + + test("treeAggregate") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def seqOp = (c: Long, x: Int) => c + x + def combOp = (c1: Long, c2: Long) => c1 + c2 + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) + assert(sum === -1000L) + } + } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (depth <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, depth) + assert(sum === -1000) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 81bebec8c7a39..017c39edb185f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -22,11 +22,11 @@ import scala.math.abs import scala.util.Random import org.scalatest.FunSuite - import org.jblas.DoubleMatrix -import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.recommendation.ALS.BlockStats object ALSSuite { @@ -67,8 +67,10 @@ object ALSSuite { case true => // Generate raw values from [0,9], or if negativeWeights, from [-2,7] val raw = new DoubleMatrix(users, products, - Array.fill(users * products)((if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*) - val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) + Array.fill(users * products)( + (if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*) + val prefs = + new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) (raw, prefs) case false => (userMatrix.mmul(productMatrix), null) } @@ -160,6 +162,22 @@ class ALSSuite extends FunSuite with LocalSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false) } + test("analyze one user block and one product block") { + val localRatings = Seq( + Rating(0, 100, 1.0), + Rating(0, 101, 2.0), + Rating(0, 102, 3.0), + Rating(1, 102, 4.0), + Rating(2, 103, 5.0)) + val ratings = sc.makeRDD(localRatings, 2) + val stats = ALS.analyzeBlocks(ratings, 1, 1) + assert(stats.size === 2) + assert(stats(0) === BlockStats("user", 0, 3, 5, 4, 3)) + assert(stats(1) === BlockStats("product", 0, 4, 5, 3, 4)) + } + + // TODO: add tests for analyzing multiple user/product blocks + /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala new file mode 100644 index 0000000000000..ed21f84472c9a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import java.io.File +import java.nio.charset.Charset + +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.Files +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils} +import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.util.Utils + +class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { + + // Assert that two values are equal within tolerance epsilon + def assertEqual(v1: Double, v2: Double, epsilon: Double) { + def errorMessage = v1.toString + " did not equal " + v2.toString + assert(math.abs(v1-v2) <= epsilon, errorMessage) + } + + // Assert that model predictions are correct + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + // A prediction is off if the prediction is more than 0.5 away from expected value. + math.abs(prediction - expected.label) > 0.5 + } + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data + test("streaming linear regression parameter accuracy") { + + val testDir = Files.createTempDir() + val numBatches = 10 + val batchDuration = Milliseconds(1000) + val ssc = new StreamingContext(sc, batchDuration) + val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.1) + .setNumIterations(50) + + model.trainOn(data) + + ssc.start() + + // write data to a file stream + for (i <- 0 until numBatches) { + val samples = LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) + val file = new File(testDir, i.toString) + Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) + Thread.sleep(batchDuration.milliseconds) + } + + ssc.stop(stopSparkContext=false) + + System.clearProperty("spark.driver.port") + Utils.deleteRecursively(testDir) + + // check accuracy of final parameter estimates + assertEqual(model.latestModel().intercept, 0.0, 0.1) + assertEqual(model.latestModel().weights(0), 10.0, 0.1) + assertEqual(model.latestModel().weights(1), 10.0, 0.1) + + // check accuracy of predictions + val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17) + validatePrediction(validationData.map(row => model.latestModel().predict(row.features)), + validationData) + } + + // Test that parameter estimates improve when learning Y = 10*X1 on streaming data + test("streaming linear regression parameter convergence") { + + val testDir = Files.createTempDir() + val batchDuration = Milliseconds(2000) + val ssc = new StreamingContext(sc, batchDuration) + val numBatches = 5 + val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val model = new 
StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.1) + .setNumIterations(50) + + model.trainOn(data) + + ssc.start() + + // write data to a file stream + val history = new ArrayBuffer[Double](numBatches) + for (i <- 0 until numBatches) { + val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) + val file = new File(testDir, i.toString) + Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) + Thread.sleep(batchDuration.milliseconds) + // wait an extra few seconds to make sure the update finishes before new data arrive + Thread.sleep(4000) + history.append(math.abs(model.latestModel().weights(0) - 10.0)) + } + + ssc.stop(stopSparkContext=false) + + System.clearProperty("spark.driver.port") + Utils.deleteRecursively(testDir) + + val deltas = history.drop(1).zip(history.dropRight(1)) + // check error stability (it always either shrinks, or increases with small tol) + assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) + // check that error shrunk on at least 2 batches + assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1) + + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index bce4251426df7..a3f76f77a5dcc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -31,6 +31,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) val data = Seq( Vectors.dense(1.0, 0.0, 0.0, -2.0), Vectors.dense(4.0, 5.0, 0.0, 3.0), @@ -46,6 +47,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val p1 = Statistics.corr(x, y, "pearson") assert(approxEqual(expected, default)) assert(approxEqual(expected, p1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val p2 = Statistics.corr(x1, y1) + assert(approxEqual(expected, p2)) + } + + // RDD of zero variance + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z).isNaN()) } test("corr(x, y) spearman") { @@ -54,6 +67,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val expected = 0.5 val s1 = Statistics.corr(x, y, "spearman") assert(approxEqual(expected, s1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val s2 = Statistics.corr(x1, y1, "spearman") + assert(approxEqual(expected, s2)) + } + + // RDD of zero variance => zero variance in ranks + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z, "spearman").isNaN()) } test("corr(X) default, pearson") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 4b7b019d820b4..db13f142df517 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -89,15 +89,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { 
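The assertions in the optimization and statistics suites in this patch move from hand-rolled compareDouble / almostEquals helpers to the comparison operators in org.apache.spark.mllib.util.TestingUtils. A minimal sketch of how those operators read inside a test, assuming the TestingUtils._ import used throughout these suites; the literal values below are illustrative and are not taken from the patch:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.util.TestingUtils._

    // Absolute tolerance: |x - y| must be below eps.
    assert(1.000001 ~== 1.0 absTol 1E-5)

    // Relative tolerance: |x - y| must be below eps * min(|x|, |y|).
    // Comparing against exact zero throws, since relative error is undefined there.
    assert(102.0 ~== 100.0 relTol 0.03)

    // Vectors are compared element-wise with the same operators.
    assert(Vectors.dense(1.0, -2.0) ~== Vectors.dense(1.0 + 1E-8, -2.0 - 1E-8) absTol 1E-6)

The ~== form throws a TestFailedException with a descriptive message when the comparison fails, while ~= returns a plain Boolean; both accept either absTol or relTol, for doubles and for vectors alike.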
.add(Vectors.dense(-1.0, 0.0, 6.0)) .add(Vectors.dense(3.0, -3.0, 0.0)) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -107,15 +107,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0)))) .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0)))) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -129,17 +129,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.dense(1.7, -0.6, 0.0)) .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) - assert(summarizer.mean.almostEquals( - Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -157,17 +157,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer = summarizer1.merge(summarizer2) - assert(summarizer.mean.almostEquals( - 
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -186,24 +186,24 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer) assert(summarizer3.count === 0) - assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") - assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5961a618c59d9..70ca7c8a266f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConverters._ + import org.scalatest.FunSuite 
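For context on the summarizer assertions above: MultivariateOnlineSummarizer accumulates column statistics one observation at a time and supports merging partial summaries, which is exactly what the suite exercises with add and merge. A minimal hand-worked sketch, with illustrative names and values, assuming the class is visible from test code as in the suite above:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

    // Accumulate observations into one summarizer, and a partial one separately
    // (e.g. as would happen per partition), then merge.
    val s1 = new MultivariateOnlineSummarizer()
      .add(Vectors.dense(1.0, 2.0))
      .add(Vectors.dense(3.0, 4.0))
    val s2 = new MultivariateOnlineSummarizer().add(Vectors.dense(5.0, 6.0))
    val merged = s1.merge(s2)

    merged.count     // 3
    merged.mean      // [3.0, 4.0]
    merged.variance  // [4.0, 4.0] -- sample variance, n - 1 in the denominator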
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -31,6 +32,31 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy) + } + + def validateRegressor( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE) + } + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -50,7 +76,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) @@ -130,7 +156,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) @@ -236,7 +262,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("extract categories from a number for multiclass classification") { val l = DecisionTree.extractMultiClassCategories(13, 10) assert(l.length === 3) - assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } test("split and bin calculations for unordered categorical variables with multiclass " + @@ -247,7 +273,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -341,7 +367,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) @@ -397,7 +423,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, numClassesForClassification = 2, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -413,7 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) assert(stats.predict === 1) - assert(stats.prob == 0.6) + assert(stats.prob === 0.6) assert(stats.impurity > 0.2) } @@ -424,7 
+450,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Regression, Variance, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) @@ -439,10 +465,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict == 0.6) + assert(stats.predict === 0.6) assert(stats.impurity > 0.2) } + test("regression stump with categorical variables of arity 2") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + + val model = DecisionTree.train(rdd, strategy) + validateRegressor(model, arr, 0.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + test("stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) @@ -460,7 +503,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -483,7 +525,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -507,7 +548,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -531,7 +571,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -587,7 +626,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) @@ -602,12 +641,78 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) + 
arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) + arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + + test("stump with 2 continuous variables for binary classification") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + assert(model.topNode.split.get.feature === 1) + } + + test("stump with categorical variables for multiclass classification, with just enough bins") { + val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + val gain = bestSplits(0)._2 + assert(gain.leftImpurity === 0) + assert(gain.rightImpurity === 0) + } + test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -625,9 +730,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, 
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -644,7 +753,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) @@ -708,6 +817,10 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = { + generateCategoricalDataPoints().toList.asJava + } + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index c14870fb969a8..8ef2bb1bf6a78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -63,9 +63,9 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { test("loadLibSVMFile") { val lines = """ - |+1 1:1.0 3:2.0 5:3.0 - |-1 - |-1 2:4.0 4:5.0 6:6.0 + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 """.stripMargin val tempDir = Files.createTempDir() tempDir.deleteOnExit() @@ -73,7 +73,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString - val pointsWithNumFeatures = loadLibSVMFile(sc, path, multiclass = false, 6).collect() + val pointsWithNumFeatures = loadLibSVMFile(sc, path, 6).collect() val pointsWithoutNumFeatures = loadLibSVMFile(sc, path).collect() for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { @@ -86,11 +86,11 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) } - val multiclassPoints = loadLibSVMFile(sc, path, multiclass = true).collect() + val multiclassPoints = loadLibSVMFile(sc, path).collect() assert(multiclassPoints.length === 3) assert(multiclassPoints(0).label === 1.0) - assert(multiclassPoints(1).label === -1.0) - assert(multiclassPoints(2).label === -1.0) + assert(multiclassPoints(1).label === 0.0) + assert(multiclassPoints(2).label === 0.0) Utils.deleteRecursively(tempDir) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 64b1ba7527183..29cc42d8cbea7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -18,28 +18,155 @@ package org.apache.spark.mllib.util import org.apache.spark.mllib.linalg.Vector +import 
org.scalatest.exceptions.TestFailedException object TestingUtils { + val ABS_TOL_MSG = " using absolute tolerance" + val REL_TOL_MSG = " using relative tolerance" + + /** + * Private helper function for comparing two values using relative tolerance. + * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue, + * the relative tolerance is meaningless, so the exception will be raised to warn users. + */ + private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + val absX = math.abs(x) + val absY = math.abs(y) + val diff = math.abs(x - y) + if (x == y) { + true + } else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) { + throw new TestFailedException( + s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0) + } else { + diff < eps * math.min(absX, absY) + } + } + + /** + * Private helper function for comparing two values using absolute tolerance. + */ + private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + math.abs(x - y) < eps + } + + case class CompareDoubleRightSide( + fun: (Double, Double, Double) => Boolean, y: Double, eps: Double, method: String) + + /** + * Implicit class for comparing two double values using relative tolerance or absolute tolerance. + */ implicit class DoubleWithAlmostEquals(val x: Double) { - // An improved version of AlmostEquals would always divide by the larger number. - // This will avoid the problem of diving by zero. - def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = { - if(x == y) { - true - } else if(math.abs(x) > math.abs(y)) { - math.abs(x - y) / math.abs(x) < epsilon - } else { - math.abs(x - y) / math.abs(y) < epsilon + + /** + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareDoubleRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two values are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareDoubleRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two values are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareDoubleRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method}.", 0) } + true } + + /** + * Throws exception when the difference of two values are within eps; otherwise, returns true. + */ + def !~==(r: CompareDoubleRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method}.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, + x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. + */ + def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, + x, eps, REL_TOL_MSG) + + override def toString = x.toString } + case class CompareVectorRightSide( + fun: (Vector, Vector, Double) => Boolean, y: Vector, eps: Double, method: String) + + /** + * Implicit class for comparing two vectors using relative tolerance or absolute tolerance. 
+ */ implicit class VectorWithAlmostEquals(val x: Vector) { - def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = { - x.toArray.corresponds(y.toArray) { - _.almostEquals(_, epsilon) + + /** + * When the difference of two vectors are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareVectorRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two vectors are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareVectorRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two vectors are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareVectorRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) } + true } + + /** + * Throws exception when the difference of two vectors are within eps; otherwise, returns true. + */ + def !~==(r: CompareVectorRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) + }, x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. Note that comparing against sparse vector + * with elements having value of zero will raise exception because it involves with + * comparing against zero. + */ + def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) + }, x, eps, REL_TOL_MSG) + + override def toString = x.toString } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala new file mode 100644 index 0000000000000..b0ecb33c28483 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.util + +import org.apache.spark.mllib.linalg.Vectors +import org.scalatest.FunSuite +import org.apache.spark.mllib.util.TestingUtils._ +import org.scalatest.exceptions.TestFailedException + +class TestingUtilsSuite extends FunSuite { + + test("Comparing doubles using relative error.") { + + assert(23.1 ~== 23.52 relTol 0.02) + assert(23.1 ~== 22.74 relTol 0.02) + assert(23.1 ~= 23.52 relTol 0.02) + assert(23.1 ~= 22.74 relTol 0.02) + assert(!(23.1 !~= 23.52 relTol 0.02)) + assert(!(23.1 !~= 22.74 relTol 0.02)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](23.1 !~== 23.52 relTol 0.02) + intercept[TestFailedException](23.1 !~== 22.74 relTol 0.02) + intercept[TestFailedException](23.1 ~== 23.63 relTol 0.02) + intercept[TestFailedException](23.1 ~== 22.34 relTol 0.02) + + assert(23.1 !~== 23.63 relTol 0.02) + assert(23.1 !~== 22.34 relTol 0.02) + assert(23.1 !~= 23.63 relTol 0.02) + assert(23.1 !~= 22.34 relTol 0.02) + assert(!(23.1 ~= 23.63 relTol 0.02)) + assert(!(23.1 ~= 22.34 relTol 0.02)) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. + intercept[TestFailedException](0.1 ~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 ~= 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~= 0.0 relTol 0.032) + intercept[TestFailedException](0.0 ~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 ~= 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~= 0.1 relTol 0.032) + + // Comparisons of numbers very close to zero. + assert(10 * Double.MinPositiveValue ~== 9.5 * Double.MinPositiveValue relTol 0.01) + assert(10 * Double.MinPositiveValue !~== 11 * Double.MinPositiveValue relTol 0.01) + + assert(-Double.MinPositiveValue ~== 1.18 * -Double.MinPositiveValue relTol 0.012) + assert(-Double.MinPositiveValue ~== 1.38 * -Double.MinPositiveValue relTol 0.012) + } + + test("Comparing doubles using absolute error.") { + + assert(17.8 ~== 17.99 absTol 0.2) + assert(17.8 ~== 17.61 absTol 0.2) + assert(17.8 ~= 17.99 absTol 0.2) + assert(17.8 ~= 17.61 absTol 0.2) + assert(!(17.8 !~= 17.99 absTol 0.2)) + assert(!(17.8 !~= 17.61 absTol 0.2)) + + // Should throw exception with message when test fails. 
+ intercept[TestFailedException](17.8 !~== 17.99 absTol 0.2) + intercept[TestFailedException](17.8 !~== 17.61 absTol 0.2) + intercept[TestFailedException](17.8 ~== 18.01 absTol 0.2) + intercept[TestFailedException](17.8 ~== 17.59 absTol 0.2) + + assert(17.8 !~== 18.01 absTol 0.2) + assert(17.8 !~== 17.59 absTol 0.2) + assert(17.8 !~= 18.01 absTol 0.2) + assert(17.8 !~= 17.59 absTol 0.2) + assert(!(17.8 ~= 18.01 absTol 0.2)) + assert(!(17.8 ~= 17.59 absTol 0.2)) + + // Comparisons of numbers very close to zero, and both side of zeros + assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + } + + test("Comparing vectors using relative error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)) + assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)) + + // Should throw exception with message when test fails. + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.dense(Array(3.13, 0.0)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.sparse(2, Array(0), Array(3.13)) relTol 0.01) + + // Comparisons of two sparse vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== + Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) + + assert(Vectors.dense(Array(3.1, 3.5)) !~== + Vectors.sparse(2, Array(0, 1), Array(3.135, 3.534)) relTol 0.01) + } + + test("Comparing vectors using absolute error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)) + + // Should throw exception with message when test fails. 
+ intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + // Comparisons of two sparse vectors + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + // Comparisons of a dense vector and a sparse vector + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.dense(Array(3.1, 1E-3, 2.4)) absTol 1E-6) + } +} diff --git a/pom.xml b/pom.xml index d2e6b3c0ed5a4..4ab027bad55c0 100644 --- a/pom.xml +++ b/pom.xml @@ -100,6 +100,7 @@ external/twitter external/kafka external/flume + external/flume-sink external/zeromq external/mqtt examples @@ -113,6 +114,7 @@ spark 2.10.4 2.10 + 2.0.1 0.18.1 shaded-protobuf org.spark-project.akka @@ -132,6 +134,8 @@ 3.0.0 1.7.6 0.7.1 + 1.8.3 + 1.1.0 64m 512m @@ -252,9 +256,15 @@ 3.3.2 - commons-codec - commons-codec - 1.5 + commons-codec + commons-codec + 1.5 + + + org.apache.commons + commons-math3 + 3.3 + test com.google.code.findbugs @@ -772,7 +782,7 @@ net.alchim31.maven scala-maven-plugin - 3.1.6 + 3.2.0 scala-compile-first @@ -818,6 +828,15 @@ -target ${java.version} + + + + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} + + @@ -851,10 +870,11 @@ ${project.build.directory}/SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - - ${session.executionRootDirectory} - 1 - + + true + ${session.executionRootDirectory} + 1 + @@ -994,6 +1014,14 @@ + + + kinesis-asl + + extras/kinesis-asl + + + java8-tests @@ -1139,5 +1167,15 @@ + + hive-thriftserver + + false + + + sql/hive-thriftserver + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5ff88f0dd1cac..537ca0dcf267d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -71,7 +71,12 @@ object MimaExcludes { "org.apache.spark.storage.TachyonStore.putValues") ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.flume.FlumeReceiver.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver.this") ) ++ Seq( // Ignore some private methods in ALS. 
ProblemFilters.exclude[MissingMethodProblem]( @@ -97,6 +102,14 @@ object MimaExcludes { "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) ++ + Seq ( // Package-private classes removed in SPARK-2341 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") ) case v if v.startsWith("1.0") => Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 62576f84dd031..40b588512ff08 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,14 +30,15 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming, - streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql", - "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq").map(ProjectRef(buildLocation, _)) - - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, + sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, + streamingTwitter, streamingZeromq) = + Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", + "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", + "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) + + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = + Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") .map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") @@ -61,7 +62,7 @@ object SparkBuild extends PomBuild { var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq.empty if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { - println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pganglia-lgpl flag.") + println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") } if (Properties.envOrNone("SPARK_HIVE").isDefined) { @@ -100,7 +101,7 @@ object SparkBuild extends PomBuild { Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1))) - case _ => + case _ => } override val userPropertiesMap = System.getProperties.toMap @@ -156,10 +157,9 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ 
optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */ // TODO: Add Sql to mima checks - allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)). - foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) + allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, + streamingFlumeSink).contains(x)).foreach(x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) @@ -167,12 +167,17 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst macro settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ enable(SQL.settings)(sql) /* Hive console settings */ enable(Hive.settings)(hive) + enable(Flume.settings)(streamingFlumeSink) + // TODO: move this to its upstream project. override def projectDefinitions(baseDirectory: File): Seq[Project] = { super.projectDefinitions(baseDirectory).map { x => @@ -183,10 +188,20 @@ object SparkBuild extends PomBuild { } -object SQL { +object Flume { + lazy val settings = sbtavro.SbtAvro.avroSettings +} +object Catalyst { lazy val settings = Seq( + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), + // Quasiquotes break compiling scala doc... + // TODO: Investigate fixing this. + sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) +} +object SQL { + lazy val settings = Seq( initialCommands in console := """ |import org.apache.spark.sql.catalyst.analysis._ @@ -201,7 +216,6 @@ object SQL { |import org.apache.spark.sql.test.TestSQLContext._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin ) - } object Hive { @@ -281,6 +295,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) .map(_.filterNot(_.getCanonicalPath.contains("network"))) + .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("executor"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) .map(_.filterNot(_.getCanonicalPath.contains("collection"))) @@ -301,7 +316,7 @@ object Unidoc { "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration", "mllib.tree.impurity", "mllib.tree.model", "mllib.util" ), - "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"), + "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" ) ) @@ -313,8 +328,10 @@ object TestSettings { lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, - javaOptions in Test += "-Dspark.home=" + sparkHome, + javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", + javaOptions in Test += "-Dspark.ports.maxRetries=100", + javaOptions in Test += "-Dspark.ui.port=0", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/project/plugins.sbt b/project/plugins.sbt index d3ac4bf335e87..06d18e193076e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -24,3 +24,5 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") 
addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.0") + +addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") diff --git a/python/lib/py4j-0.8.1-src.zip b/python/lib/py4j-0.8.1-src.zip deleted file mode 100644 index 2069a328d1f2e..0000000000000 Binary files a/python/lib/py4j-0.8.1-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.8.2.1-src.zip b/python/lib/py4j-0.8.2.1-src.zip new file mode 100644 index 0000000000000..5203b84d9119e Binary files /dev/null and b/python/lib/py4j-0.8.2.1-src.zip differ diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 312c75d112cbf..c58555fc9d2c5 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -49,6 +49,16 @@ Main entry point for accessing data stored in Apache Hive.. """ +# The following block allows us to import python's random instead of mllib.random for scripts in +# mllib that depend on top level pyspark packages, which transitively depend on python's random. +# Since Python's import logic looks for modules in the current package first, we eliminate +# mllib.random as a candidate for C{import random} by removing the first search path, the script's +# location, in order to force the loader to look in Python's top-level modules for C{random}. +import sys +s = sys.path.pop(0) +import random +sys.path.insert(0, s) + from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.sql import SQLContext diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 2204e9c9ca701..45d36e5d0e764 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -86,6 +86,7 @@ Exception:... """ +import select import struct import SocketServer import threading @@ -209,19 +210,38 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ + This handler will keep polling updates from the same socket until the + server is shutdown. + """ + def handle(self): from pyspark.accumulators import _accumulatorRegistry - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + +class AccumulatorServer(SocketServer.TCPServer): + """ + A simple TCP server that intercepts shutdown() in order to interrupt + our continuous polling on the handler. 
+ """ + server_shutdown = False + def shutdown(self): + self.server_shutdown = True + SocketServer.TCPServer.shutdown(self) def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 4fda2a9b950b8..68062483dedaa 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -560,8 +560,9 @@ class ItemGetterType(ctypes.Structure): ] - itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents - return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,)) + obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents + return self.save_reduce(operator.itemgetter, + obj.item if obj.nitems > 1 else (obj.item,)) if PyObject_HEAD: dispatch[operator.itemgetter] = save_itemgetter diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 830a6ee03f2a6..2e80eb50f2207 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -60,6 +60,7 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, @@ -83,7 +84,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, @param serializer: The serializer for RDDs. @param conf: A L{SparkConf} object setting Spark properties. @param gateway: Use an existing gateway and JVM, otherwise a new JVM - will be instatiated. + will be instantiated. >>> from pyspark.context import SparkContext @@ -378,7 +379,7 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None): + valueConverter=None, minSplits=None, batchSize=None): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -398,14 +399,18 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, @param valueConverter: @param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) + @param batchSize: The number of Python objects represented as a single + Java object. 
(default sc._default_batch_size_for_serialized_input) """ minSplits = minSplits or min(self.defaultParallelism, 2) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, - keyConverter, valueConverter, minSplits) - return RDD(jrdd, self, PickleSerializer()) + keyConverter, valueConverter, minSplits, batchSize) + return RDD(jrdd, self, ser) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -425,14 +430,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -449,14 +458,18 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -476,14 +489,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. 
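# A hypothetical usage sketch of the batchSize argument added to the Hadoop input
# methods above; the path and app name are made up. The default batch size of 10
# selects BatchedSerializer(PickleSerializer()) on the Python side, while batchSize=1
# falls back to one pickled Python object per Java object.
from pyspark import SparkContext

sc = SparkContext("local[2]", "batch-size-demo")
# Default batching (10 Python objects per serialized Java object):
pairs = sc.sequenceFile("hdfs:///data/pairs.seq")
# Explicitly unbatched, e.g. when individual records are very large:
big_pairs = sc.sequenceFile("hdfs:///data/pairs.seq", batchSize=1)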
(default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -500,11 +517,15 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, - keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 8a5873ded2b8b..b00da833d06f1 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -15,64 +15,39 @@ # limitations under the License. # +import numbers import os import signal +import select import socket import sys import traceback -import multiprocessing -from ctypes import c_bool from errno import EINTR, ECHILD from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main -from pyspark.serializers import write_int - -try: - POOLSIZE = multiprocessing.cpu_count() -except NotImplementedError: - POOLSIZE = 4 - -exit_flag = multiprocessing.Value(c_bool, False) - - -def should_exit(): - global exit_flag - return exit_flag.value +from pyspark.serializers import read_int, write_int def compute_real_exit_code(exit_code): # SystemExit's code can be integer or string, but os._exit only accepts integers - import numbers if isinstance(exit_code, numbers.Integral): return exit_code else: return 1 -def worker(listen_sock): +def worker(sock): + """ + Called by a worker process after the fork(). 
+ """ # Redirect stdout to stderr os.dup2(2, 1) sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 - # Manager sends SIGHUP to request termination of workers in the pool - def handle_sighup(*args): - assert should_exit() - signal.signal(SIGHUP, handle_sighup) - - # Cleanup zombie children - def handle_sigchld(*args): - pid = status = None - try: - while (pid, status) != (0, 0): - pid, status = os.waitpid(0, os.WNOHANG) - except EnvironmentError as err: - if err.errno == EINTR: - # retry - handle_sigchld() - elif err.errno != ECHILD: - raise - signal.signal(SIGCHLD, handle_sigchld) + signal.signal(SIGHUP, SIG_DFL) + signal.signal(SIGCHLD, SIG_DFL) + signal.signal(SIGTERM, SIG_DFL) # Blocks until the socket is closed by draining the input stream # until it raises an exception or returns EOF. @@ -85,55 +60,24 @@ def waitSocketClose(sock): except: pass - # Handle clients - while not should_exit(): - # Wait until a client arrives or we have to exit - sock = None - while not should_exit() and sock is None: - try: - sock, addr = listen_sock.accept() - except EnvironmentError as err: - if err.errno != EINTR: - raise - - if sock is not None: - # Fork a child to handle the client. - # The client is handled in the child so that the manager - # never receives SIGCHLD unless a worker crashes. - if os.fork() == 0: - # Leave the worker pool - signal.signal(SIGHUP, SIG_DFL) - signal.signal(SIGCHLD, SIG_DFL) - listen_sock.close() - # Read the socket using fdopen instead of socket.makefile() because the latter - # seems to be very slow; note that we need to dup() the file descriptor because - # otherwise writes also cause a seek that makes us miss data on the read side. - infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - exit_code = 0 - try: - worker_main(infile, outfile) - except SystemExit as exc: - exit_code = exc.code - finally: - outfile.flush() - # The Scala side will close the socket upon task completion. - waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) - else: - sock.close() - - -def launch_worker(listen_sock): - if os.fork() == 0: - try: - worker(listen_sock) - except Exception as err: - traceback.print_exc() - os._exit(1) - else: - assert should_exit() - os._exit(0) + # Read the socket using fdopen instead of socket.makefile() because the latter + # seems to be very slow; note that we need to dup() the file descriptor because + # otherwise writes also cause a seek that makes us miss data on the read side. + infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + exit_code = 0 + try: + # Acknowledge that the fork was successful + write_int(os.getpid(), outfile) + outfile.flush() + worker_main(infile, outfile) + except SystemExit as exc: + exit_code = exc.code + finally: + outfile.flush() + # The Scala side will close the socket upon task completion. 
+ waitSocketClose(sock) + os._exit(compute_real_exit_code(exit_code)) def manager(): @@ -143,29 +87,28 @@ def manager(): # Create a listening socket on the AF_INET loopback interface listen_sock = socket.socket(AF_INET, SOCK_STREAM) listen_sock.bind(('127.0.0.1', 0)) - listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN)) + listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() write_int(listen_port, sys.stdout) - # Launch initial worker pool - for idx in range(POOLSIZE): - launch_worker(listen_sock) - listen_sock.close() - - def shutdown(): - global exit_flag - exit_flag.value = True + def shutdown(code): + signal.signal(SIGTERM, SIG_DFL) + # Send SIGHUP to notify workers of shutdown + os.kill(0, SIGHUP) + exit(code) - # Gracefully exit on SIGTERM, don't die on SIGHUP - signal.signal(SIGTERM, lambda signum, frame: shutdown()) - signal.signal(SIGHUP, SIG_IGN) + def handle_sigterm(*args): + shutdown(1) + signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM + signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP # Cleanup zombie children def handle_sigchld(*args): try: pid, status = os.waitpid(0, os.WNOHANG) - if status != 0 and not should_exit(): - raise RuntimeError("worker crashed: %s, %s" % (pid, status)) + if status != 0: + msg = "worker %s crashed abruptly with exit status %s" % (pid, status) + print >> sys.stderr, msg except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): raise @@ -174,20 +117,52 @@ def handle_sigchld(*args): # Initialization complete sys.stdout.close() try: - while not should_exit(): + while True: try: - # Spark tells us to exit by closing stdin - if os.read(0, 512) == '': - shutdown() - except EnvironmentError as err: - if err.errno != EINTR: - shutdown() + ready_fds = select.select([0, listen_sock], [], [])[0] + except select.error as ex: + if ex[0] == EINTR: + continue + else: raise + if 0 in ready_fds: + try: + worker_pid = read_int(sys.stdin) + except EOFError: + # Spark told us to exit by closing stdin + shutdown(0) + try: + os.kill(worker_pid, signal.SIGKILL) + except OSError: + pass # process already died + + + if listen_sock in ready_fds: + sock, addr = listen_sock.accept() + # Launch a worker process + try: + pid = os.fork() + if pid == 0: + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: + sock.close() + + except OSError as e: + print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e + outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + write_int(-1, outfile) # Signal that the fork failed + outfile.flush() + outfile.close() + sock.close() finally: - signal.signal(SIGTERM, SIG_DFL) - exit_flag.value = True - # Send SIGHUP to notify workers of shutdown - os.kill(0, SIGHUP) + shutdown(1) if __name__ == '__main__': diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 8e3ad6b783b6c..9c1565affbdac 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -101,7 +101,7 @@ def _serialize_double(d): """ Serialize a double (float or numpy.float64) into a mutually understood format. 
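# An illustrative restatement (not the implementation itself, which is above) of the
# control flow in the rewritten daemon manager: a single select() over stdin and the
# listening socket. Readable stdin means either EOF (shut down) or a worker PID to
# SIGKILL; a readable listening socket means "fork one child per connection". For
# brevity this sketch reads the PID as a text line, whereas the real daemon reads a
# packed integer via read_int().
import os
import select
import signal
import sys

def serve(listen_sock, handle_client):
    while True:
        ready = select.select([0, listen_sock], [], [])[0]
        if 0 in ready:
            line = sys.stdin.readline()
            if not line:
                return                              # parent closed stdin: shut down
            os.kill(int(line), signal.SIGKILL)      # kill the requested worker
        if listen_sock in ready:
            sock, _ = listen_sock.accept()
            if os.fork() == 0:                      # child: serve one client, then exit
                listen_sock.close()
                handle_client(sock)
                os._exit(0)
            sock.close()                            # parent keeps only the listener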
""" - if type(d) == float or type(d) == float64: + if type(d) == float or type(d) == float64 or type(d) == int or type(d) == long: d = float64(d) ba = bytearray(8) _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64) @@ -176,6 +176,10 @@ def _deserialize_double(ba, offset=0): True >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0 True + >>> _deserialize_double(_serialize_double(1)) == 1.0 + True + >>> _deserialize_double(_serialize_double(1L)) == 1.0 + True >>> x = sys.float_info.max >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x True @@ -339,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype): temp_array[...] = array -def _get_unmangled_rdd(data, serializer): +def _get_unmangled_rdd(data, serializer, cache=True): + """ + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ dataBytes = data.map(serializer) dataBytes._bypass_serializer = True - dataBytes.cache() # TODO: users should unpersist() this later! + if cache: + dataBytes.cache() return dataBytes -# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - return _get_unmangled_rdd(data, _serialize_double_vector) +def _get_unmangled_double_vector_rdd(data, cache=True): + """ + Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of + _serialized_double_vectors. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_double_vector, cache) -# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points -def _get_unmangled_labeled_point_rdd(data): - return _get_unmangled_rdd(data, _serialize_labeled_point) +def _get_unmangled_labeled_point_rdd(data, cache=True): + """ + Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_labeled_point, cache) # Common functions for dealing with and training linear models @@ -376,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs): if x.size != coeffs.shape[0]: raise RuntimeError("Got sparse vector of size %d; wanted %d" % ( x.size, coeffs.shape[0])) - elif (type(x) == RDD): + elif isinstance(x, RDD): raise RuntimeError("Bulk predict not yet supported.") else: raise TypeError("Argument of type " + type(x).__name__ + " unsupported") diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 9e28dfbb9145d..5ec1a8084d269 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -66,17 +66,43 @@ def predict(self, x): if margin > 0: prob = 1 / (1 + exp(-margin)) else: - prob = 1 - 1 / (1 + exp(margin)) + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) return 1 if prob > 0.5 else 0 class LogisticRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None): - """Train a logistic regression model on the given data.""" + def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, + initialWeights=None, regParam=1.0, regType=None, intercept=False): + """ + Train a logistic regression model on the given data. + + @param data: The training data. 
+ @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regParam: The regularizer parameter (default: 1.0). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i) + d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data, initialWeights) @@ -114,11 +140,35 @@ def predict(self, x): class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None): - """Train a support vector machine on the given data.""" + miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): + """ + Train a support vector machine on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param regParam: The regularizer parameter (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i) + d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 71f4ad1a8d44e..54720c2324ca6 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -255,4 +255,8 @@ def _test(): exit(-1) if __name__ == "__main__": + # remove current path from list of search paths to avoid importing mllib.random + # for C{import random}, which is done in an external dependency of pyspark during doctests. + import sys + sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py new file mode 100644 index 0000000000000..36e710dbae7a8 --- /dev/null +++ b/python/pyspark/mllib/random.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for random data generation. +""" + + +from pyspark.rdd import RDD +from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector +from pyspark.serializers import NoOpSerializer + +class RandomRDDGenerators: + """ + Generator methods for creating RDDs comprised of i.i.d samples from + some distribution. + """ + + @staticmethod + def uniformRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the + uniform distribution on [0.0, 1.0]. + + To transform the distribution in the generated RDD from U[0.0, 1.0] + to U[a, b], use + C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + .map(lambda v: a + (b - a) * v)} + + >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> len(x) + 100 + >>> max(x) <= 1.0 and min(x) >= 0.0 + True + >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + 4 + >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts == sc.defaultParallelism + True + """ + jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def normalRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the standard normal + distribution. + + To transform the distribution in the generated RDD from standard normal + to some other normal N(mean, sigma), use + C{RandomRDDGenerators.normal(sc, n, p, seed)\ + .map(lambda v: mean + sigma * v)} + + >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - 0.0) < 0.1 + True + >>> abs(stats.stdev() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def poissonRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the Poisson + distribution with the input mean. + + >>> mean = 100.0 + >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the uniform distribution on [0.0 1.0]. 
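# A usage sketch of RandomRDDGenerators, restating the shift-and-scale transformations
# the docstrings above describe; a, b, mu and sigma are arbitrary example values.
from pyspark import SparkContext
from pyspark.mllib.random import RandomRDDGenerators

sc = SparkContext("local[2]", "random-rdd-demo")
a, b = -5.0, 5.0
uniform_ab = RandomRDDGenerators.uniformRDD(sc, 1000, seed=42L) \
    .map(lambda v: a + (b - a) * v)        # U[0, 1] -> U[a, b]
mu, sigma = 10.0, 2.0
normal = RandomRDDGenerators.normalRDD(sc, 1000, seed=42L) \
    .map(lambda v: mu + sigma * v)         # N(0, 1) -> N(mu, sigma)
poisson = RandomRDDGenerators.poissonRDD(sc, mean=100.0, size=1000)
print normal.stats()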
+ + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat.shape + (10, 10) + >>> mat.max() <= 1.0 and mat.min() >= 0.0 + True + >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + 4 + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the standard normal distribution. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - 0.0) < 0.1 + True + >>> abs(mat.std() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the Poisson distribution with the input mean. + + >>> import numpy as np + >>> mean = 100.0 + >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index b84bc531dec8c..041b119269427 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase): class LinearRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, - miniBatchFraction=1.0, initialWeights=None): - """Train a linear regression model on the given data.""" + def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, + initialWeights=None, regParam=1.0, regType=None, intercept=False): + """ + Train a linear regression model on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regParam: The regularizer parameter (default: 1.0). + @param regType: The type of regularizer used for training + our model. 
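# A hypothetical call showing the new regParam/regType/intercept keywords on
# LinearRegressionWithSGD.train; the toy data and hyperparameter values are made up.
# The same keywords were added to LogisticRegressionWithSGD and SVMWithSGD above.
from numpy import array
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

sc = SparkContext("local[2]", "regtype-demo")
data = sc.parallelize([LabeledPoint(3.0, [1.0, 2.0]),
                       LabeledPoint(5.0, [2.0, 3.0]),
                       LabeledPoint(7.0, [3.0, 4.0])])
# L2 (ridge-style) penalty with a fitted intercept; "l1" and "none" are also accepted.
model = LinearRegressionWithSGD.train(data, iterations=100, step=0.01,
                                      regParam=0.1, regType="l2", intercept=True)
print model.predict(array([4.0, 5.0]))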
+ Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i) + d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights) diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py new file mode 100644 index 0000000000000..0a08a562d1f1f --- /dev/null +++ b/python/pyspark/mllib/stat.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for statistical functions in MLlib. +""" + +from pyspark.mllib._common import \ + _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ + _serialize_double, _serialize_double_vector, \ + _deserialize_double, _deserialize_double_matrix + +class Statistics(object): + + @staticmethod + def corr(x, y=None, method=None): + """ + Compute the correlation (matrix) for the input RDD(s) using the + specified method. + Methods currently supported: I{pearson (default), spearman}. + + If a single RDD of Vectors is passed in, a correlation matrix + comparing the columns in the input RDD is returned. Use C{method=} + to specify the method to be used for single RDD inout. + If two RDDs of floats are passed in, a single float is returned. + + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) + >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) + >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) + >>> abs(Statistics.corr(x, y) - 0.6546537) < 1e-7 + True + >>> Statistics.corr(x, y) == Statistics.corr(x, y, "pearson") + True + >>> Statistics.corr(x, y, "spearman") + 0.5 + >>> from math import isnan + >>> isnan(Statistics.corr(x, zeros)) + True + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), + ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) + >>> Statistics.corr(rdd) + array([[ 1. , 0.05564149, nan, 0.40047142], + [ 0.05564149, 1. , nan, 0.91359586], + [ nan, nan, 1. , nan], + [ 0.40047142, 0.91359586, nan, 1. ]]) + >>> Statistics.corr(rdd, method="spearman") + array([[ 1. , 0.10540926, nan, 0.4 ], + [ 0.10540926, 1. , nan, 0.9486833 ], + [ nan, nan, 1. , nan], + [ 0.4 , 0.9486833 , nan, 1. ]]) + >>> try: + ... Statistics.corr(rdd, "spearman") + ... print "Method name as second argument without 'method=' shouldn't be allowed." + ... except TypeError: + ... 
pass + """ + sc = x.ctx + # Check inputs to determine whether a single value or a matrix is needed for output. + # Since it's legal for users to use the method name as the second argument, we need to + # check if y is used to specify the method name instead. + if type(y) == str: + raise TypeError("Use 'method=' to specify method name.") + if not y: + try: + Xser = _get_unmangled_double_vector_rdd(x) + except TypeError: + raise TypeError("corr called on a single RDD not consisted of Vectors.") + resultMat = sc._jvm.PythonMLLibAPI().corr(Xser._jrdd, method) + return _deserialize_double_matrix(resultMat) + else: + xSer = _get_unmangled_rdd(x, _serialize_double) + ySer = _get_unmangled_rdd(y, _serialize_double) + result = sc._jvm.PythonMLLibAPI().corr(xSer._jrdd, ySer._jrdd, method) + return result + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 37ccf1d590743..9d1e5be637a9a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -100,6 +100,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -127,9 +128,19 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = \ + DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -157,6 +168,14 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = \ + DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -229,6 +248,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -256,9 +276,18 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 
3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -286,6 +315,13 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + if __name__ == "__main__": if not _have_scipy: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py new file mode 100644 index 0000000000000..1e0006df75ac6 --- /dev/null +++ b/python/pyspark/mllib/tree.py @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_collections import MapConverter + +from pyspark import SparkContext, RDD +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \ + _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \ + _deserialize_double +from pyspark.mllib.regression import LabeledPoint +from pyspark.serializers import NoOpSerializer + +class DecisionTreeModel(object): + """ + A decision tree model for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + """ + + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def predict(self, x): + """ + Predict the label of one or more examples. + :param x: Data point (feature vector), + or an RDD of data points (feature vectors). 
+ """ + pythonAPI = self._sc._jvm.PythonMLLibAPI() + if isinstance(x, RDD): + # Bulk prediction + if x.count() == 0: + return self._sc.parallelize([]) + dataBytes = _get_unmangled_double_vector_rdd(x, cache=False) + jSerializedPreds = \ + pythonAPI.predictDecisionTreeModel(self._java_model, + dataBytes._jrdd) + serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer()) + return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes))) + else: + # Assume x is a single data point. + x_ = _serialize_double_vector(x) + return pythonAPI.predictDecisionTreeModel(self._java_model, x_) + + def numNodes(self): + return self._java_model.numNodes() + + def depth(self): + return self._java_model.depth() + + def __str__(self): + return self._java_model.toString() + + +class DecisionTree(object): + """ + Learning algorithm for a decision tree model + for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + + Example usage: + >>> from numpy import array, ndarray + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(1.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) + >>> print(model) + DecisionTreeModel classifier + If (feature 0 <= 0.5) + Predict: 0.0 + Else (feature 0 > 0.5) + Predict: 1.0 + + >>> model.predict(array([1.0])) > 0 + True + >>> model.predict(array([0.0])) == 0 + True + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model.predict(array([0.0, 1.0])) == 1 + True + >>> model.predict(array([0.0, 0.0])) == 0 + True + >>> model.predict(SparseVector(2, {1: 1.0})) == 1 + True + >>> model.predict(SparseVector(2, {1: 0.0})) == 0 + True + """ + + @staticmethod + def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + impurity="gini", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for classification. + + :param data: Training data: RDD of LabeledPoint. + Labels are integers {0,1,...,numClasses}. + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: Supported values: "entropy" or "gini" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "classification", numClasses, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + @staticmethod + def trainRegressor(data, categoricalFeaturesInfo={}, + impurity="variance", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for regression. + + :param data: Training data: RDD of LabeledPoint. + Labels are real numbers. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. 
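# An end-to-end sketch of the new DecisionTree Python API; the toy data and the
# categoricalFeaturesInfo mapping are illustrative (feature 0 is declared to have
# three categories, every other feature is treated as continuous).
from numpy import array
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

sc = SparkContext("local[2]", "decision-tree-demo")
data = sc.parallelize([LabeledPoint(0.0, [0.0, 1.0]),
                       LabeledPoint(1.0, [1.0, 0.0]),
                       LabeledPoint(1.0, [2.0, 0.0]),
                       LabeledPoint(0.0, [0.0, 2.0])])
model = DecisionTree.trainClassifier(data, numClasses=2,
                                     categoricalFeaturesInfo={0: 3})
print model.predict(array([1.0, 0.0]))                               # single point
print model.predict(data.map(lambda p: p.features)).collect()        # bulk prediction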
+ :param impurity: Supported values: "variance" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "regression", 0, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + + @staticmethod + def train(data, algo, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins=100): + """ + Train a DecisionTreeModel for classification or regression. + + :param data: Training data: RDD of LabeledPoint. + For classification, labels are integers + {0,1,...,numClasses}. + For regression, labels are real numbers. + :param algo: "classification" or "regression" + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: For classification: "entropy" or "gini". + For regression: "variance". + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, algo, + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index a707a9dcd5b49..639cda6350229 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -16,6 +16,7 @@ # import numpy as np +import warnings from pyspark.mllib.linalg import Vectors, SparseVector from pyspark.mllib.regression import LabeledPoint @@ -31,13 +32,16 @@ class MLUtils: @staticmethod def _parse_libsvm_line(line, multiclass): + warnings.warn("deprecated", DeprecationWarning) + return _parse_libsvm_line(line) + + @staticmethod + def _parse_libsvm_line(line): """ Parses a line in LIBSVM format into (label, indices, values). """ items = line.split(None) label = float(items[0]) - if not multiclass: - label = 1.0 if label > 0.5 else 0.0 nnz = len(items) - 1 indices = np.zeros(nnz, dtype=np.int32) values = np.zeros(nnz) @@ -66,6 +70,11 @@ def _convert_labeled_point_to_libsvm(p): @staticmethod def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + warnings.warn("deprecated", DeprecationWarning) + return loadLibSVMFile(sc, path, numFeatures, minPartitions) + + @staticmethod + def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): """ Loads labeled data in the LIBSVM format into an RDD of LabeledPoint. 
The LIBSVM format is a text-based format used by @@ -81,13 +90,6 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non @param sc: Spark context @param path: file or directory path in any Hadoop-supported file system URI - @param multiclass: whether the input labels contain more than - two classes. If false, any label with value - greater than 0.5 will be mapped to 1.0, or - 0.0 otherwise. So it works for both +1/-1 and - 1/0 cases. If true, the double value parsed - directly from the label string will be used - as the label value. @param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. This is useful when the dataset is @@ -105,7 +107,6 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() - >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect() >>> tempFile.close() >>> type(examples[0]) == LabeledPoint True @@ -114,20 +115,18 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non >>> type(examples[1]) == LabeledPoint True >>> print examples[1] - (0.0,(6,[],[])) + (-1.0,(6,[],[])) >>> type(examples[2]) == LabeledPoint True >>> print examples[2] - (0.0,(6,[1,3,5],[4.0,5.0,6.0])) - >>> multiclass_examples[1].label - -1.0 + (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) """ lines = sc.textFile(path, minPartitions) - parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l, multiclass)) + parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) if numFeatures <= 0: parsed.cache() - numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 + numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b84d976114f0d..309f5a9b6038d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -231,6 +231,13 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() + def _toPickleSerialization(self): + if (self._jrdd_deserializer == PickleSerializer() or + self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): + return self + else: + return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def id(self): """ A unique ID for this RDD (within its SparkContext). @@ -311,9 +318,9 @@ def map(self, f, preservesPartitioning=False): >>> sorted(rdd.map(lambda x: (x, 1)).collect()) [('a', 1), ('b', 1), ('c', 1)] """ - def func(split, iterator): + def func(_, iterator): return imap(f, iterator) - return PipelinedRDD(self, func, preservesPartitioning) + return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): """ @@ -1030,6 +1037,113 @@ def first(self): """ return self.take(1)[0] + def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the new Hadoop OutputFormat API (mapreduce package). Keys/values are + converted for output using either user specified converters or, by default, + L{org.apache.spark.api.python.JavaToWritableConverter}. 
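# Hypothetical usage of the updated MLUtils.loadLibSVMFile; "/tmp/demo.libsvm" is a
# made-up path. Labels are now used exactly as parsed, so a "-1 ..." line yields -1.0
# instead of being squashed to 0.0, and the old multiclass flag is deprecated.
from pyspark import SparkContext
from pyspark.mllib.util import MLUtils

sc = SparkContext("local[2]", "libsvm-demo")
with open("/tmp/demo.libsvm", "w") as f:
    f.write("+1 1:1.0 3:2.0\n-1 2:4.0\n")
examples = MLUtils.loadLibSVMFile(sc, "/tmp/demo.libsvm").collect()
print [p.label for p in examples]    # [1.0, -1.0]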
+ + @param conf: Hadoop job configuration, passed in as a dict + @param keyConverter: (None by default) + @param valueConverter: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + keyConverter, valueConverter, True) + + def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, + keyConverter=None, valueConverter=None, conf=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the new Hadoop OutputFormat API (mapreduce package). Key and value types + will be inferred if not specified. Keys and values are converted for output using either + user specified converters or L{org.apache.spark.api.python.JavaToWritableConverter}. The + C{conf} is applied on top of the base Hadoop conf associated with the SparkContext + of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. + + @param path: path to Hadoop file + @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.IntWritable", None by default) + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.Text", None by default) + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: Hadoop job configuration, passed in as a dict (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) + + def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the old Hadoop OutputFormat API (mapred package). Keys/values are + converted for output using either user specified converters or, by default, + L{org.apache.spark.api.python.JavaToWritableConverter}. + + @param conf: Hadoop job configuration, passed in as a dict + @param keyConverter: (None by default) + @param valueConverter: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + keyConverter, valueConverter, False) + + def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, + keyConverter=None, valueConverter=None, conf=None, + compressionCodecClass=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the old Hadoop OutputFormat API (mapred package). Key and value types + will be inferred if not specified. Keys and values are converted for output using either + user specified converters or L{org.apache.spark.api.python.JavaToWritableConverter}. 
The + C{conf} is applied on top of the base Hadoop conf associated with the SparkContext + of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. + + @param path: path to Hadoop file + @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.IntWritable", None by default) + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.Text", None by default) + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: (None by default) + @param compressionCodecClass: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, + jconf, compressionCodecClass) + + def saveAsSequenceFile(self, path, compressionCodecClass=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the L{org.apache.hadoop.io.Writable} types that we convert from the + RDD's key and value types. The mechanism is as follows: + 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. + 2. Keys and values of this Java RDD are converted to Writables and written out. + + @param path: path to sequence file + @param compressionCodecClass: (None by default) + """ + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + path, compressionCodecClass) + def saveAsPickleFile(self, path, batchSize=10): """ Save this RDD as a SequenceFile of serialized objects. 
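# A hypothetical round trip through the new save API above; "/tmp/demo-seq" is a
# made-up path. saveAsSequenceFile converts the pickled pairs to Writables via
# Pyrolite, and sc.sequenceFile reads them back as Python key-value pairs.
from pyspark import SparkContext

sc = SparkContext("local[2]", "sequencefile-demo")
pairs = sc.parallelize([(1, "a"), (2, "b"), (3, "c")])
pairs.saveAsSequenceFile("/tmp/demo-seq")
print sorted(sc.sequenceFile("/tmp/demo-seq").collect())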
The serializer @@ -1070,7 +1184,7 @@ def func(split, iterator): if not isinstance(x, basestring): x = unicode(x) yield x.encode("utf-8") - keyed = PipelinedRDD(self, func) + keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) @@ -1268,7 +1382,7 @@ def add_shuffle_key(split, iterator): yield pack_long(split) yield outputSerializer.dumps(items) - keyed = PipelinedRDD(self, add_shuffle_key) + keyed = self.mapPartitionsWithIndex(add_shuffle_key) keyed._bypass_serializer = True with _JavaStackTrace(self.context) as st: pairRDD = self.ctx._jvm.PairwiseRDD( diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 03b31ae9624c2..a10f85b55ad30 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -65,6 +65,9 @@ import marshal import struct import sys +import types +import collections + from pyspark import cloudpickle @@ -267,6 +270,67 @@ def dumps(self, obj): return obj +# Hook namedtuple, make it picklable + +__cls = {} + + +def _restore(name, fields, value): + """ Restore an object of namedtuple""" + k = (name, fields) + cls = __cls.get(k) + if cls is None: + cls = collections.namedtuple(name, fields) + __cls[k] = cls + return cls(*value) + + +def _hack_namedtuple(cls): + """ Make class generated by namedtuple picklable """ + name = cls.__name__ + fields = cls._fields + def __reduce__(self): + return (_restore, (name, fields, tuple(self))) + cls.__reduce__ = __reduce__ + return cls + + +def _hijack_namedtuple(): + """ Hack namedtuple() to make it picklable """ + # hijack only one time + if hasattr(collections.namedtuple, "__hijack"): + return + + global _old_namedtuple # or it will put in closure + def _copy_func(f): + return types.FunctionType(f.func_code, f.func_globals, f.func_name, + f.func_defaults, f.func_closure) + + _old_namedtuple = _copy_func(collections.namedtuple) + + def namedtuple(name, fields, verbose=False, rename=False): + cls = _old_namedtuple(name, fields, verbose, rename) + return _hack_namedtuple(cls) + + # replace namedtuple with new one + collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.func_code = namedtuple.func_code + collections.namedtuple.__hijack = 1 + + # hack the cls already generated by namedtuple + # those created in other module can be pickled as normal, + # so only hack those in __main__ module + for n, o in sys.modules["__main__"].__dict__.iteritems(): + if (type(o) is type and o.__base__ is tuple + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__): + _hack_namedtuple(o) # hack inplace + + +_hijack_namedtuple() + + class PickleSerializer(FramedSerializer): """ Serializes objects using Python's cPickle serializer: diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index cb83e89176823..f1093701ddc89 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,12 +15,869 @@ # limitations under the License. 
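The _hijack_namedtuple() change above is what lets namedtuples defined in the driver's __main__ module travel through pickling into RDD operations; a brief illustrative sketch, assuming a SparkContext `sc`:

    from collections import namedtuple

    Point = namedtuple("Point", "x y")         # patched namedtuple: instances now pickle cleanly
    points = sc.parallelize([Point(1, 2), Point(3, 4)])
    points.map(lambda p: p.x + p.y).collect()  # [3, 7]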
# + +import sys +import types +import itertools +import warnings +import decimal +import datetime +import keyword +import warnings +from array import array +from operator import itemgetter + from pyspark.rdd import RDD, PipelinedRDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer + +from itertools import chain, ifilter, imap from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter + + +__all__ = [ + "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", + "ShortType", "ArrayType", "MapType", "StructField", "StructType", + "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", + "SchemaRDD", "Row"] + + +class DataType(object): + """Spark SQL DataType""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + +class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" + + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" + + __metaclass__ = PrimitiveTypeSingleton + + def __eq__(self, other): + # because they should be the same object + return self is other + + +class StringType(PrimitiveType): + """Spark SQL StringType + + The data type representing string values. + """ + + +class BinaryType(PrimitiveType): + """Spark SQL BinaryType + + The data type representing bytearray values. + """ + + +class BooleanType(PrimitiveType): + """Spark SQL BooleanType + + The data type representing bool values. + """ + + +class TimestampType(PrimitiveType): + """Spark SQL TimestampType + + The data type representing datetime.datetime values. + """ + + +class DecimalType(PrimitiveType): + """Spark SQL DecimalType + + The data type representing decimal.Decimal values. + """ + + +class DoubleType(PrimitiveType): + """Spark SQL DoubleType + + The data type representing float values. + """ + + +class FloatType(PrimitiveType): + """Spark SQL FloatType + + The data type representing single precision floating-point values. + """ + + +class ByteType(PrimitiveType): + """Spark SQL ByteType + + The data type representing int values with 1 singed byte. + """ + + +class IntegerType(PrimitiveType): + """Spark SQL IntegerType + + The data type representing int values. + """ + + +class LongType(PrimitiveType): + """Spark SQL LongType + + The data type representing long values. If the any value is + beyond the range of [-9223372036854775808, 9223372036854775807], + please use DecimalType. + """ + + +class ShortType(PrimitiveType): + """Spark SQL ShortType + + The data type representing int values with 2 signed bytes. + """ + + +class ArrayType(DataType): + """Spark SQL ArrayType + + The data type representing list values. An ArrayType object + comprises two fields, elementType (a DataType) and containsNull (a bool). + The field of elementType is used to specify the type of array elements. + The field of containsNull is used to specify if the array has None values. 
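Because PrimitiveTypeSingleton caches one instance per class, the primitive type objects defined above compare by identity; a short sketch of the resulting behaviour:

    StringType() is StringType()     # True: the metaclass hands back the cached singleton
    IntegerType() == IntegerType()   # True: __eq__ reduces to an identity check
    StringType() == IntegerType()    # False: different singleton instances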
+ + """ + + def __init__(self, elementType, containsNull=False): + """Creates an ArrayType + + :param elementType: the data type of elements. + :param containsNull: indicates whether the list contains None values. + + >>> ArrayType(StringType) == ArrayType(StringType, False) + True + >>> ArrayType(StringType, True) == ArrayType(StringType) + False + """ + self.elementType = elementType + self.containsNull = containsNull + + def __str__(self): + return "ArrayType(%s,%s)" % (self.elementType, + str(self.containsNull).lower()) + + +class MapType(DataType): + """Spark SQL MapType + + The data type representing dict values. A MapType object comprises + three fields, keyType (a DataType), valueType (a DataType) and + valueContainsNull (a bool). + + The field of keyType is used to specify the type of keys in the map. + The field of valueType is used to specify the type of values in the map. + The field of valueContainsNull is used to specify if values of this + map has None values. + + For values of a MapType column, keys are not allowed to have None values. + + """ + + def __init__(self, keyType, valueType, valueContainsNull=True): + """Creates a MapType + :param keyType: the data type of keys. + :param valueType: the data type of values. + :param valueContainsNull: indicates whether values contains + null values. + + >>> (MapType(StringType, IntegerType) + ... == MapType(StringType, IntegerType, True)) + True + >>> (MapType(StringType, IntegerType, False) + ... == MapType(StringType, FloatType)) + False + """ + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def __repr__(self): + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) + + +class StructField(DataType): + """Spark SQL StructField + + Represents a field in a StructType. + A StructField object comprises three fields, name (a string), + dataType (a DataType) and nullable (a bool). The field of name + is the name of a StructField. The field of dataType specifies + the data type of a StructField. + + The field of nullable specifies if values of a StructField can + contain None values. + + """ + + def __init__(self, name, dataType, nullable): + """Creates a StructField + :param name: the name of this field. + :param dataType: the data type of this field. + :param nullable: indicates whether values of this field + can be null. + + >>> (StructField("f1", StringType, True) + ... == StructField("f1", StringType, True)) + True + >>> (StructField("f1", StringType, True) + ... == StructField("f2", StringType, True)) + False + """ + self.name = name + self.dataType = dataType + self.nullable = nullable + + def __repr__(self): + return "StructField(%s,%s,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) + + +class StructType(DataType): + """Spark SQL StructType + + The data type representing rows. + A StructType object comprises a list of L{StructField}s. + + """ + + def __init__(self, fields): + """Creates a StructType + + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True), + ... 
[StructField("f2", IntegerType, False)]]) + >>> struct1 == struct2 + False + """ + self.fields = fields + + def __repr__(self): + return ("StructType(List(%s))" % + ",".join(str(field) for field in self.fields)) + + +def _parse_datatype_list(datatype_list_string): + """Parses a list of comma separated data types.""" + index = 0 + datatype_list = [] + start = 0 + depth = 0 + while index < len(datatype_list_string): + if depth == 0 and datatype_list_string[index] == ",": + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + start = index + 1 + elif datatype_list_string[index] == "(": + depth += 1 + elif datatype_list_string[index] == ")": + depth -= 1 + + index += 1 + + # Handle the last data type + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + return datatype_list + + +_all_primitive_types = dict((k, v) for k, v in globals().iteritems() + if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) + + +def _parse_datatype_string(datatype_string): + """Parses the given data type string. + + >>> def check_datatype(datatype): + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) + ... python_datatype = _parse_datatype_string( + ... scala_datatype.toString()) + ... return datatype == python_datatype + >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) + True + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + True + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + True + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + True + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False)]) + >>> check_datatype(complex_structtype) + True + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + True + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) + >>> check_datatype(complex_maptype) + True + """ + index = datatype_string.find("(") + if index == -1: + # It is a primitive type. 
+ index = len(datatype_string) + type_or_field = datatype_string[:index] + rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() + + if type_or_field in _all_primitive_types: + return _all_primitive_types[type_or_field]() + + elif type_or_field == "ArrayType": + last_comma_index = rest_part.rfind(",") + containsNull = True + if rest_part[last_comma_index + 1:].strip().lower() == "false": + containsNull = False + elementType = _parse_datatype_string( + rest_part[:last_comma_index].strip()) + return ArrayType(elementType, containsNull) + + elif type_or_field == "MapType": + last_comma_index = rest_part.rfind(",") + valueContainsNull = True + if rest_part[last_comma_index + 1:].strip().lower() == "false": + valueContainsNull = False + keyType, valueType = _parse_datatype_list( + rest_part[:last_comma_index].strip()) + return MapType(keyType, valueType, valueContainsNull) + + elif type_or_field == "StructField": + first_comma_index = rest_part.find(",") + name = rest_part[:first_comma_index].strip() + last_comma_index = rest_part.rfind(",") + nullable = True + if rest_part[last_comma_index + 1:].strip().lower() == "false": + nullable = False + dataType = _parse_datatype_string( + rest_part[first_comma_index + 1:last_comma_index].strip()) + return StructField(name, dataType, nullable) + + elif type_or_field == "StructType": + # rest_part should be in the format like + # List(StructField(field1,IntegerType,false)). + field_list_string = rest_part[rest_part.find("(") + 1:-1] + fields = _parse_datatype_list(field_list_string) + return StructType(fields) + + +# Mapping Python types to Spark SQL DateType +_type_mappings = { + bool: BooleanType, + int: IntegerType, + long: LongType, + float: DoubleType, + str: StringType, + unicode: StringType, + decimal.Decimal: DecimalType, + datetime.datetime: TimestampType, + datetime.date: TimestampType, + datetime.time: TimestampType, +} + + +def _infer_type(obj): + """Infer the DataType from obj""" + if obj is None: + raise ValueError("Can not infer type for None") + + dataType = _type_mappings.get(type(obj)) + if dataType is not None: + return dataType() + + if isinstance(obj, dict): + if not obj: + raise ValueError("Can not infer type for empty dict") + key, value = obj.iteritems().next() + return MapType(_infer_type(key), _infer_type(value), True) + elif isinstance(obj, (list, array)): + if not obj: + raise ValueError("Can not infer type for empty list/array") + return ArrayType(_infer_type(obj[0]), True) + else: + try: + return _infer_schema(obj) + except ValueError: + raise ValueError("not supported type: %s" % type(obj)) + + +def _infer_schema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, tuple): + if hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + elif hasattr(row, "__FIELDS__"): # Row + items = zip(row.__FIELDS__, tuple(row)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in row): + items = row + else: + raise ValueError("Can't infer schema from tuple") + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise ValueError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _infer_type(v), True) for k, v in items] + return StructType(fields) + + +def _create_converter(obj, dataType): + """Create an converter to drop the names of fields in obj """ + if not _has_struct(dataType): + return lambda x: x + + elif isinstance(dataType, 
ArrayType): + conv = _create_converter(obj[0], dataType.elementType) + return lambda row: map(conv, row) + + elif isinstance(dataType, MapType): + value = obj.values()[0] + conv = _create_converter(value, dataType.valueType) + return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + + # dataType must be StructType + names = [f.name for f in dataType.fields] + + if isinstance(obj, dict): + conv = lambda o: tuple(o.get(n) for n in names) + + elif isinstance(obj, tuple): + if hasattr(obj, "_fields"): # namedtuple + conv = tuple + elif hasattr(obj, "__FIELDS__"): + conv = tuple + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + conv = lambda o: tuple(v for k, v in o) + else: + raise ValueError("unexpected tuple") + + elif hasattr(obj, "__dict__"): # object + conv = lambda o: [o.__dict__.get(n, None) for n in names] + + nested = any(_has_struct(f.dataType) for f in dataType.fields) + if not nested: + return conv + + row = conv(obj) + convs = [_create_converter(v, f.dataType) + for v, f in zip(row, dataType.fields)] -__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] + def nested_conv(row): + return tuple(f(v) for f, v in zip(convs, conv(row))) + + return nested_conv + + +def _drop_schema(rows, schema): + """ all the names of fields, becoming tuples""" + iterator = iter(rows) + row = iterator.next() + converter = _create_converter(row, schema) + yield converter(row) + for i in iterator: + yield converter(i) + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> _split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRACKETS: + brackets.append(c) + elif c in _BRACKETS.values(): + if not brackets or c != _BRACKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,None,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(...c,None,true),StructField(d... + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(None,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(None,ArrayType(None,true),true),true) + """ + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, None, True) + + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... 
+ >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,None,true))),true) + """ + s = s.strip() + if not s: + return + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(None, _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types infered from obj + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> _infer_schema_type(row, schema) + StructType...IntegerType...DoubleType...StringType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... + """ + if dataType is None: + return _infer_type(obj) + + if not obj: + raise ValueError("Can not infer type from empty value") + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = obj.iteritems().next() + return MapType(_infer_type(k), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise ValueError("Unexpected dataType: %s" % dataType) + + +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (long,), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + TimestampType: (datetime.datetime,), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. + + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, IntegerType()) + >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
+ """ + # all objects are nullable + if obj is None: + return + + _type = type(dataType) + if _type not in _acceptable_types: + return + + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept abject in type %s" + % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.iteritems(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with" + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + + +_cached_cls = {} + + +def _restore_object(dataType, obj): + """ Restore object during unpickling. """ + # use id(dataType) as key to speed up lookup in dict + # Because of batched pickling, dataType will be the + # same object in mose cases. + k = id(dataType) + cls = _cached_cls.get(k) + if cls is None: + # use dataType as key to avoid create multiple class + cls = _cached_cls.get(dataType) + if cls is None: + cls = _create_cls(dataType) + _cached_cls[dataType] = cls + _cached_cls[k] = cls + return cls(obj) + + +def _create_object(cls, v): + """ Create an customized object with class `cls`. """ + return cls(v) if v is not None else v + + +def _create_getter(dt, i): + """ Create a getter for item `i` with schema """ + cls = _create_cls(dt) + + def getter(self): + return _create_object(cls, self[i]) + + return getter + + +def _has_struct(dt): + """Return whether `dt` is or has StructType in it""" + if isinstance(dt, StructType): + return True + elif isinstance(dt, ArrayType): + return _has_struct(dt.elementType) + elif isinstance(dt, MapType): + return _has_struct(dt.valueType) + return False + + +def _create_properties(fields): + """Create properties according to fields""" + ps = {} + for i, f in enumerate(fields): + name = f.name + if (name.startswith("__") and name.endswith("__") + or keyword.iskeyword(name)): + warnings.warn("field name %s can not be accessed in Python," + "use position to access it instead" % name) + if _has_struct(f.dataType): + # delay creating object until accessing it + getter = _create_getter(f.dataType, i) + else: + getter = itemgetter(i) + ps[name] = property(getter) + return ps + + +def _create_cls(dataType): + """ + Create an class by dataType + + The created class is similar to namedtuple, but can have nested schema. 
+ + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> import pickle + >>> pickle.loads(pickle.dumps(obj)) + Row(a=1, b=1.0, c='str') + + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> pickle.loads(pickle.dumps(obj)) + Row(a=[1], b={'key': Row(c=1, d=2.0)}) + """ + + if isinstance(dataType, ArrayType): + cls = _create_cls(dataType.elementType) + + class List(list): + + def __getitem__(self, i): + # create object with datetype + return _create_object(cls, list.__getitem__(self, i)) + + def __repr__(self): + # call collect __repr__ for nested objects + return "[%s]" % (", ".join(repr(self[i]) + for i in range(len(self)))) + + def __reduce__(self): + return list.__reduce__(self) + + return List + + elif isinstance(dataType, MapType): + vcls = _create_cls(dataType.valueType) + + class Dict(dict): + + def __getitem__(self, k): + # create object with datetype + return _create_object(vcls, dict.__getitem__(self, k)) + + def __repr__(self): + # call collect __repr__ for nested objects + return "{%s}" % (", ".join("%r: %r" % (k, self[k]) + for k in self)) + + def __reduce__(self): + return dict.__reduce__(self) + + return Dict + + elif not isinstance(dataType, StructType): + raise Exception("unexpected data type: %s" % dataType) + + class Row(tuple): + """ Row in SchemaRDD """ + __DATATYPE__ = dataType + __FIELDS__ = tuple(f.name for f in dataType.fields) + __slots__ = () + + # create property for fast access + locals().update(_create_properties(dataType.fields)) + + def __repr__(self): + # call collect __repr__ for nested objects + return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) + for n in self.__FIELDS__)) + + def __reduce__(self): + return (_restore_object, (self.__DATATYPE__, tuple(self))) + + return Row class SQLContext: @@ -39,7 +896,7 @@ def __init__(self, sparkContext, sqlContext=None): >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError:... + TypeError:... >>> bad_rdd = sc.parallelize([1,2,3]) >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -47,17 +904,23 @@ def __init__(self, sparkContext, sqlContext=None): ... ValueError:... - >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L, - ... "boolean" : True}]) - >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long, - ... x.boolean)) - >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True) + >>> from datetime import datetime + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> srdd = sqlCtx.inferSchema(allTypes) + >>> srdd.registerTempTable("allTypes") + >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... 
x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap + self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray if sqlContext: self._scala_SQLContext = sqlContext @@ -73,38 +936,168 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + [Row(c0=5)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + command = (func, + BatchedSerializer(PickleSerializer(), 1024), + BatchedSerializer(PickleSerializer(), 1024)) + env = MapConverter().convert(self._sc.environment, + self._sc._gateway._gateway_client) + includes = ListConverter().convert(self._sc._python_includes, + self._sc._gateway._gateway_client) + self._ssql_ctx.registerPython(name, + bytearray(CloudPickleSerializer().dumps(command)), + env, + includes, + self._sc.pythonExec, + self._sc._javaAccumulator, + str(returnType)) + def inferSchema(self, rdd): - """Infer and apply a schema to an RDD of L{dict}s. + """Infer and apply a schema to an RDD of L{Row}s. + + We peek at the first row of the RDD to determine the fields' names + and types. Nested collections are supported, which include array, + dict, list, Row, tuple, namedtuple, or object. + + All the rows in `rdd` should have the same type with the first one, + or it will cause runtime exceptions. - We peek at the first row of the RDD to determine the fields names - and types, and then use that to extract all the dictionaries. Nested - collections are supported, which include array, dict, list, set, and - tuple. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects, + using dict is deprecated. + >>> rdd = sc.parallelize( + ... [Row(field1=1, field2="row1"), + ... Row(field1=2, field2="row2"), + ... Row(field1=3, field2="row3")]) >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, - ... {"field1" : 3, "field2": "row3"}] - True + >>> srdd.collect()[0] + Row(field1=1, field2=u'row1') - >>> from array import array + >>> NestedRow = Row("f1", "f2") + >>> nestedRdd1 = sc.parallelize([ + ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), + ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, - ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] - True + >>> srdd.collect() + [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] + >>> nestedRdd2 = sc.parallelize([ + ... 
NestedRow([[1, 2], [2, 3]], [1, 2]), + ... NestedRow([[2, 3], [3, 4]], [2, 3])]) >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, - ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] - True + >>> srdd.collect() + [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] """ - if (rdd.__class__ is SchemaRDD): - raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) - elif not isinstance(rdd.first(), dict): - raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" % - (SchemaRDD.__name__, rdd.first())) - jrdd = self._pythonToJavaMap(rdd._jrdd) - srdd = self._ssql_ctx.inferSchema(jrdd.rdd()) + if isinstance(rdd, SchemaRDD): + raise TypeError("Cannot apply schema to SchemaRDD") + + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated") + + schema = _infer_schema(first) + rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + return self.applySchema(rdd, schema) + + def applySchema(self, rdd, schema): + """ + Applies the given schema to the given RDD of L{tuple} or L{list}s. + + These tuples or lists can contain complex nested structures like + lists, maps or nested rows. + + The schema should be a StructType. + + It is important that the schema matches the types of the objects + in each row or exceptions could be thrown at runtime. + + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... StructField("field2", StringType(), False)]) + >>> srdd = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT * from table1") + >>> srdd2.collect() + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] + + >>> from datetime import datetime + >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3], None)]) + >>> schema = StructType([ + ... StructField("byte1", ByteType(), False), + ... StructField("byte2", ByteType(), False), + ... StructField("short1", ShortType(), False), + ... StructField("short2", ShortType(), False), + ... StructField("int", IntegerType(), False), + ... StructField("float", FloatType(), False), + ... StructField("time", TimestampType(), False), + ... StructField("map", + ... MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", + ... StructType([StructField("b", ShortType(), False)]), False), + ... StructField("list", ArrayType(ByteType(), False), False), + ... StructField("null", DoubleType(), True)]) + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> results = srdd.map( + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, + ... x.map["a"], x.struct.b, x.list, x.null)) + >>> results.collect()[0] + (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> srdd.registerTempTable("table2") + >>> sqlCtx.sql( + ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + + ... "float + 1.1 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] + + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... 
{"a": 1}, (2,), [1, 2, 3])]) + >>> abstract = "byte short float time map{} struct(b) list[]" + >>> schema = _parse_schema_abstract(abstract) + >>> typedSchema = _infer_schema_type(rdd.first(), schema) + >>> srdd = sqlCtx.applySchema(rdd, typedSchema) + >>> srdd.collect() + [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] + """ + + if isinstance(rdd, SchemaRDD): + raise TypeError("Cannot apply schema to SchemaRDD") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) + jrdd = self._pythonToJava(rdd._jrdd, batched) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) return SchemaRDD(srdd, self) def registerRDDAsTable(self, rdd, tableName): @@ -137,10 +1130,16 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path): - """Loads a text file storing one JSON object per line, - returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. + def jsonFile(self, path, schema=None): + """ + Loads a text file storing one JSON object per line as a + L{SchemaRDD}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it goes through the entire dataset once to determine + the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -149,42 +1148,104 @@ def jsonFile(self, path): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in srdd2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in srdd4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... 
"field3.field5[0] as f3 from table3") + >>> srdd6.collect() + [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ - jschema_rdd = self._ssql_ctx.jsonFile(path) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonFile(path) + else: + scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(jschema_rdd, self) - def jsonRDD(self, rdd): - """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. + def jsonRDD(self, rdd, schema=None): + """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - >>> srdd = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it goes through the entire dataset once to determine + the schema. + + >>> srdd1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in srdd2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in srdd4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... 
"field3.field5[0] as f3 from table3") + >>> srdd6.collect() + [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] """ - def func(split, iterator): + + def func(iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) yield x.encode("utf-8") - keyed = PipelinedRDD(rdd, func) + keyed = rdd.mapPartitions(func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + else: + scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) def sql(self, sqlQuery): @@ -193,9 +1254,8 @@ def sql(self, sqlQuery): >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"}, - ... {"f1" : 3, "f2": "row3"}] - True + >>> srdd2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) @@ -233,7 +1293,8 @@ def _ssql_ctx(self): self._scala_HiveContext = self._get_hive_ctx() return self._scala_HiveContext except Py4JError as e: - raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " "sbt/sbt assembly", e) def _get_hive_ctx(self): @@ -241,14 +1302,20 @@ def _get_hive_ctx(self): def hiveql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}. + DEPRECATED: Use sql() """ + warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", + DeprecationWarning) return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) def hql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}. + DEPRECATED: Use sql() """ + warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by" + + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", + DeprecationWarning) return self.hiveql(hqlQuery) @@ -261,13 +1328,17 @@ class LocalHiveContext(HiveContext): >>> import os >>> hiveCtx = LocalHiveContext(sc) >>> try: - ... supress = hiveCtx.hql("DROP TABLE src") + ... supress = hiveCtx.sql("DROP TABLE src") ... except Exception: ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1) - >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1])) + >>> kv1 = os.path.join(os.environ["SPARK_HOME"], + ... 'examples/src/main/resources/kv1.txt') + >>> supress = hiveCtx.sql( + ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" + ... % kv1) + >>> results = hiveCtx.sql("FROM src SELECT value" + ... 
).map(lambda r: int(r.value.split('_')[1])) >>> num = results.count() >>> reduce_sum = results.reduce(lambda x, y: x + y) >>> num @@ -276,6 +1347,11 @@ class LocalHiveContext(HiveContext): 130091 """ + def __init__(self, sparkContext, sqlContext=None): + HiveContext.__init__(self, sparkContext, sqlContext) + warnings.warn("LocalHiveContext is deprecated. " + "Use HiveContext instead.", DeprecationWarning) + def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) @@ -286,25 +1362,83 @@ def _get_hive_ctx(self): return self._jvm.TestHiveContext(self._jsc.sc()) -# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples -# are custom classes that must be generated per Schema. -class Row(dict): - """A row in L{SchemaRDD}. +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row - An extended L{dict} that takes a L{dict} in its constructor, and - exposes those items as fields. - >>> r = Row({"hello" : "world", "foo" : "bar"}) - >>> r.hello - 'world' - >>> r.foo - 'bar' +class Row(tuple): """ + A row in L{SchemaRDD}. The fields in it can be accessed like attributes. - def __init__(self, d): - d.update(self.__dict__) - self.__dict__ = d - dict.__init__(self, d) + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) class SchemaRDD(RDD): @@ -318,6 +1452,10 @@ class SchemaRDD(RDD): implementation is an RDD composed of Java objects. Instead it is converted to a PythonRDD in the JVM, on which Python operations can be done. + + This class receives raw tuples from Java but assigns a class to it in + all its data-collection methods (mapPartitionsWithIndex, collect, take, + etc) so that PySpark sees them as Row objects with named fields. 
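Putting the pieces together, a minimal end-to-end sketch of the Python SQL API introduced by this patch (assumes `sc`; the table name, UDF name, and result column labels are illustrative):

    sqlCtx = SQLContext(sc)
    people = sc.parallelize([Row(name=u"Alice", age=11), Row(name=u"Bob", age=14)])
    srdd = sqlCtx.inferSchema(people)              # schema inferred from the first Row
    srdd.registerTempTable("people")               # replaces the deprecated registerAsTable()
    sqlCtx.registerFunction("strLen", lambda s: len(s), IntegerType())
    sqlCtx.sql("SELECT name, strLen(name) FROM people WHERE age > 12").collect()
    # roughly [Row(name=u'Bob', c1=3)]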
""" def __init__(self, jschema_rdd, sql_ctx): @@ -328,7 +1466,8 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_cached = False self.is_checkpointed = False self.ctx = self.sql_ctx._sc - self._jrdd_deserializer = self.ctx.serializer + # the _jrdd is created by javaToPython(), serialized by pickle + self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -338,7 +1477,7 @@ def _jrdd(self): L{pyspark.rdd.RDD} super class (map, filter, etc.). """ if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._toPython()._jrdd + self._lazy_jrdd = self._jschema_rdd.javaToPython() return self._lazy_jrdd @property @@ -362,19 +1501,23 @@ def saveAsParquetFile(self, path): """ self._jschema_rdd.saveAsParquetFile(path) - def registerAsTable(self, name): + def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} that was used to create this SchemaRDD. >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerAsTable("test") + >>> srdd.registerTempTable("test") >>> srdd2 = sqlCtx.sql("select * from test") >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - self._jschema_rdd.registerAsTable(name) + self._jschema_rdd.registerTempTable(name) + + def registerAsTable(self, name): + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this SchemaRDD into the specified table. @@ -387,6 +1530,11 @@ def saveAsTable(self, tableName): """Creates a new table with the contents of this SchemaRDD.""" self._jschema_rdd.saveAsTable(tableName) + def schema(self): + """Returns the schema of this SchemaRDD (represented by + a L{StructType}).""" + return _parse_datatype_string(self._jschema_rdd.schema().toString()) + def schemaString(self): """Returns the output schema in the tree format.""" return self._jschema_rdd.schemaString() @@ -410,19 +1558,45 @@ def count(self): """ return self._jschema_rdd.count() - def _toPython(self): - # We have to import the Row class explicitly, so that the reference Pickler has is - # pyspark.sql.Row instead of __main__.Row - from pyspark.sql import Row - jrdd = self._jschema_rdd.javaToPython() - # TODO: This is inefficient, we should construct the Python Row object - # in Java land in the javaToPython function. May require a custom - # pickle serializer in Pyrolite - return RDD(jrdd, self._sc, BatchedSerializer( - PickleSerializer())).map(lambda d: Row(d)) - - # We override the default cache/persist/checkpoint behavior as we want to cache the underlying - # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class + def collect(self): + """ + Return a list that contains all of the rows in this RDD. + + Each object in the list is on Row, the fields can be accessed as + attributes. + """ + rows = RDD.collect(self) + cls = _create_cls(self.schema()) + return map(cls, rows) + + # Convert each object in the RDD to a Row with the right class + # for this SchemaRDD, so that fields can be accessed as attributes. + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD, + while tracking the index of the original partition. 
+ + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(splitIndex, iterator): yield splitIndex + >>> rdd.mapPartitionsWithIndex(f).sum() + 6 + """ + rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) + + schema = self.schema() + import pickle + pickle.loads(pickle.dumps(schema)) + + def applySchema(_, it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) + return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) + + # We override the default cache/persist/checkpoint behavior + # as we want to cache the underlying SchemaRDD object in the JVM, + # not the PythonRDD checkpointed by the super class def cache(self): self.is_cached = True self._jschema_rdd.cache() @@ -434,9 +1608,9 @@ def persist(self, storageLevel): self._jschema_rdd.persist(javaStorageLevel) return self - def unpersist(self): + def unpersist(self, blocking=True): self.is_cached = False - self._jschema_rdd.unpersist() + self._jschema_rdd.unpersist(blocking) return self def checkpoint(self): @@ -477,7 +1651,8 @@ def subtract(self, other, numPartitions=None): if numPartitions is None: rdd = self._jschema_rdd.subtract(other._jschema_rdd) else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions) + rdd = self._jschema_rdd.subtract(other._jschema_rdd, + numPartitions) return SchemaRDD(rdd, self.sql_ctx) else: raise ValueError("Can only subtract another SchemaRDD") @@ -487,31 +1662,31 @@ def _test(): import doctest from array import array from pyspark.context import SparkContext - globs = globals().copy() + # let doctest run in pyspark.sql, so DataTypes can be picklable + import pyspark.sql + from pyspark.sql import Row, SQLContext + globs = pyspark.sql.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( - [{"field1": 1, "field2": "row1"}, - {"field1": 2, "field2": "row2"}, - {"field1": 3, "field2": "row3"}] + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] ) jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}' + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) - globs['nestedRdd1'] = sc.parallelize([ - {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, - {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) - globs['nestedRdd2'] = sc.parallelize([ - {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)}, - {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}]) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod( + pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index e287bd3da1f61..1e597d64e03fe 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -20,6 +20,13 @@ import copy import math +try: + from numpy import maximum, 
minimum, sqrt +except ImportError: + maximum = max + minimum = min + sqrt = math.sqrt + class StatCounter(object): @@ -39,10 +46,8 @@ def merge(self, value): self.n += 1 self.mu += delta / self.n self.m2 += delta * (value - self.mu) - if self.maxValue < value: - self.maxValue = value - if self.minValue > value: - self.minValue = value + self.maxValue = maximum(self.maxValue, value) + self.minValue = minimum(self.minValue, value) return self @@ -70,8 +75,8 @@ def mergeStats(self, other): else: self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) - self.maxValue = max(self.maxValue, other.maxValue) - self.minValue = min(self.minValue, other.minValue) + self.maxValue = maximum(self.maxValue, other.maxValue) + self.minValue = minimum(self.minValue, other.minValue) self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) self.n += other.n @@ -115,14 +120,14 @@ def sampleVariance(self): # Return the standard deviation of the values. def stdev(self): - return math.sqrt(self.variance()) + return sqrt(self.variance()) # # Return the sample standard deviation of the values, which corrects for bias in estimating the # variance by dividing by N-1 instead of N. # def sampleStdev(self): - return math.sqrt(self.sampleVariance()) + return sqrt(self.sampleVariance()) def __repr__(self): return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 63cc5e9ad96fa..4ac94ba729d35 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,6 +19,7 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ +from array import array from fileinput import input from glob import glob import os @@ -37,12 +38,19 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger _have_scipy = False +_have_numpy = False try: import scipy.sparse _have_scipy = True except: # No SciPy, but that's okay, we'll skip those tests pass +try: + import numpy as np + _have_numpy = True +except: + # No NumPy, but that's okay, we'll skip those tests + pass SPARK_HOME = os.environ["SPARK_HOME"] @@ -104,6 +112,17 @@ def test_huge_dataset(self): m._cleanup() +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from cPickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEquals(p1, p2) + + class PySparkTestCase(unittest.TestCase): def setUp(self): @@ -165,11 +184,17 @@ class TestAddFile(PySparkTestCase): def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that # this job fails due to `userlibrary` not being on the Python path: + # disable logging in log4j temporarily + log4j = self.sc._jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) def func(x): from userlibrary import UserClass return UserClass().hello() self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) + log4j.LogManager.getRootLogger().setLevel(old_level) + # Add the file, so the job should now succeed: path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") self.sc.addPyFile(path) @@ -278,6 +303,20 @@ def combOp(x, y): self.assertEqual(set([2]), sets[3]) self.assertEqual(set([1, 3]), sets[5]) + def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], 
rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + + def test_namedtuple_in_rdd(self): + from collections import namedtuple + Person = namedtuple("Person", "id firstName lastName") + jon = Person(1, "Jon", "Doe") + jane = Person(2, "Jane", "Doe") + theDoes = self.sc.parallelize([jon, jane]) + self.assertEquals([jon, jane], theDoes.collect()) + class TestIO(PySparkTestCase): @@ -315,6 +354,17 @@ def test_sequencefiles(self): ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] self.assertEqual(doubles, ed) + bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) + ebs = [(1, bytearray('aa', 'utf-8')), + (1, bytearray('aa', 'utf-8')), + (2, bytearray('aa', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (3, bytearray('cc', 'utf-8'))] + self.assertEqual(bytes, ebs) + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", "org.apache.hadoop.io.Text", "org.apache.hadoop.io.Text").collect()) @@ -341,14 +391,34 @@ def test_sequencefiles(self): maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable").collect()) - em = [(1, {2.0: u'aa'}), + em = [(1, {}), (1, {3.0: u'bb'}), (2, {1.0: u'aa'}), (2, {1.0: u'cc'}), - (2, {3.0: u'bb'}), (3, {2.0: u'dd'})] self.assertEqual(maps, em) + # arrays get pickled to tuples by default + tuples = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable").collect()) + et = [(1, ()), + (2, (3.0, 4.0, 5.0)), + (3, (4.0, 5.0, 6.0))] + self.assertEqual(tuples, et) + + # with custom converters, primitive arrays can stay as arrays + arrays = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + ea = [(1, array('d')), + (2, array('d', [3.0, 4.0, 5.0])), + (3, array('d', [4.0, 5.0, 6.0]))] + self.assertEqual(arrays, ea) + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) @@ -357,6 +427,12 @@ def test_sequencefiles(self): u'double': 54.0, u'int': 123, u'str': u'test1'}) self.assertEqual(clazz[0], ec) + unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + batchSize=1).collect()) + self.assertEqual(unbatched_clazz[0], ec) + def test_oldhadoop(self): basepath = self.tempdir.name ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", @@ -367,10 +443,11 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - hello = self.sc.hadoopFile(hellopath, - "org.apache.hadoop.mapred.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text").collect() + oldconf = {"mapred.input.dir" : hellopath} + hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=oldconf).collect() result = [(0, u'Hello World!')] self.assertEqual(hello, result) @@ -385,10 +462,11 @@ def test_newhadoop(self): 
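For reference, the read calls exercised in the tests above boil down to the following usage shape; the paths are placeholders, while the Writable and converter class names are the ones the tests themselves use.

    # (IntWritable, Text) pairs come back as plain Python (int, unicode) tuples.
    pairs = sc.sequenceFile("/path/to/sftext",
                            "org.apache.hadoop.io.IntWritable",
                            "org.apache.hadoop.io.Text").collect()

    # With a custom value converter, primitive arrays survive as array('d') values.
    arrays = sc.sequenceFile("/path/to/sfarray",
                             "org.apache.hadoop.io.IntWritable",
                             "org.apache.spark.api.python.DoubleArrayWritable",
                             valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()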
self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - hello = self.sc.newAPIHadoopFile(hellopath, - "org.apache.hadoop.mapreduce.lib.input.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text").collect() + newconf = {"mapred.input.dir" : hellopath} + hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=newconf).collect() result = [(0, u'Hello World!')] self.assertEqual(hello, result) @@ -423,16 +501,267 @@ def test_bad_inputs(self): "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text")) - def test_converter(self): + def test_converters(self): + # use of custom converters basepath = self.tempdir.name maps = sorted(self.sc.sequenceFile( basepath + "/sftestdata/sfmap/", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable", - valueConverter="org.apache.spark.api.python.TestConverter").collect()) - em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])] + keyConverter="org.apache.spark.api.python.TestInputKeyConverter", + valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) + em = [(u'\x01', []), + (u'\x01', [3.0]), + (u'\x02', [1.0]), + (u'\x02', [1.0]), + (u'\x03', [2.0])] self.assertEqual(maps, em) +class TestOutputFormat(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + + def tearDown(self): + PySparkTestCase.tearDown(self) + shutil.rmtree(self.tempdir.name, ignore_errors=True) + + def test_sequencefiles(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") + ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) + self.assertEqual(ints, ei) + + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") + doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) + self.assertEqual(doubles, ed) + + ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] + self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") + bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) + self.assertEqual(bytes, ebs) + + et = [(u'1', u'aa'), + (u'2', u'bb'), + (u'3', u'cc')] + self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") + text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) + self.assertEqual(text, et) + + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") + bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) + self.assertEqual(bools, eb) + + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") + nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) + self.assertEqual(nulls, en) + + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") + maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect()) + self.assertEqual(maps, em) + + def test_oldhadoop(self): + basepath = 
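The hadoopRDD/newAPIHadoopRDD variants used above take no path argument; the input location travels in the Hadoop job configuration instead. A minimal sketch (the directory is a placeholder):

    input_conf = {"mapred.input.dir": "/path/to/hello"}
    lines = sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
                         "org.apache.hadoop.io.LongWritable",
                         "org.apache.hadoop.io.Text",
                         conf=input_conf).collect()
    # e.g. [(0, u'Hello World!')] for a one-line input file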
self.tempdir.name + dict_data = [(1, {}), + (1, {"row1" : 1.0}), + (2, {"row2" : 2.0})] + self.sc.parallelize(dict_data).saveAsHadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable") + result = sorted(self.sc.hadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect()) + self.assertEqual(result, dict_data) + + conf = { + "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.MapWritable", + "mapred.output.dir" : basepath + "/olddataset/"} + self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) + input_conf = {"mapred.input.dir" : basepath + "/olddataset/"} + old_dataset = sorted(self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + conf=input_conf).collect()) + self.assertEqual(old_dataset, dict_data) + + def test_newhadoop(self): + basepath = self.tempdir.name + # use custom ArrayWritable types and converters to handle arrays + array_data = [(1, array('d')), + (1, array('d', [1.0, 2.0, 3.0])), + (2, array('d', [3.0, 4.0, 5.0]))] + self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + self.assertEqual(result, array_data) + + conf = {"mapreduce.outputformat.class" : + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.spark.api.python.DoubleArrayWritable", + "mapred.output.dir" : basepath + "/newdataset/"} + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(conf, + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + input_conf = {"mapred.input.dir" : basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", + conf=input_conf).collect()) + self.assertEqual(new_dataset, array_data) + + def test_newolderror(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/newolderror/saveAsHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/newolderror/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + rdd = 
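saveAsHadoopDataset, exercised above, is the fully conf-driven form of the old-API save path: the output format, key/value classes and target directory all come from the job configuration rather than from positional arguments. A condensed sketch; the output directory is a placeholder:

    data = sc.parallelize([(1, {"row1": 1.0}), (2, {"row2": 2.0})])
    output_conf = {
        "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
        "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
        "mapred.output.value.class": "org.apache.hadoop.io.MapWritable",
        "mapred.output.dir": "/path/to/olddataset",
    }
    data.saveAsHadoopDataset(output_conf)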
self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/badinputs/saveAsHadoopFile/", + "org.apache.hadoop.mapred.NotValidOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/badinputs/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + data = [(1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/converters/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", + valueConverter="org.apache.spark.api.python.TestOutputValueConverter") + converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) + expected = [(u'1', 3.0), + (u'2', 1.0), + (u'3', 2.0)] + self.assertEqual(converted, expected) + + def test_reserialization(self): + basepath = self.tempdir.name + x = range(1, 5) + y = range(1001, 1005) + data = zip(x, y) + rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) + rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") + result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) + self.assertEqual(result1, data) + + rdd.saveAsHadoopFile(basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") + result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) + self.assertEqual(result2, data) + + rdd.saveAsNewAPIHadoopFile(basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) + self.assertEqual(result3, data) + + conf4 = { + "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.dir" : basepath + "/reserialize/dataset"} + rdd.saveAsHadoopDataset(conf4) + result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) + self.assertEqual(result4, data) + + conf5 = {"mapreduce.outputformat.class" : + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.dir" : basepath + "/reserialize/newdataset"} + rdd.saveAsNewAPIHadoopDataset(conf5) + result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) + self.assertEqual(result5, data) + + def test_unbatched_save_and_read(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei, numSlices=len(ei)).saveAsSequenceFile( + basepath + "/unbatched/") + + unbatched_sequence = sorted(self.sc.sequenceFile(basepath + "/unbatched/", + batchSize=1).collect()) + self.assertEqual(unbatched_sequence, ei) + + unbatched_hadoopFile = sorted(self.sc.hadoopFile(basepath + "/unbatched/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + batchSize=1).collect()) + self.assertEqual(unbatched_hadoopFile, ei) + + unbatched_newAPIHadoopFile = 
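The batchSize=1 argument appearing in the unbatched tests above disables batched pickling, so each record is serialized on its own; a one-line usage sketch with a placeholder path:

    records = sorted(sc.sequenceFile("/path/to/unbatched", batchSize=1).collect())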
sorted(self.sc.newAPIHadoopFile(basepath + "/unbatched/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + batchSize=1).collect()) + self.assertEqual(unbatched_newAPIHadoopFile, ei) + + oldconf = {"mapred.input.dir" : basepath + "/unbatched/"} + unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=oldconf, + batchSize=1).collect()) + self.assertEqual(unbatched_hadoopRDD, ei) + + newconf = {"mapred.input.dir" : basepath + "/unbatched/"} + unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=newconf, + batchSize=1).collect()) + self.assertEqual(unbatched_newAPIHadoopRDD, ei) + + def test_malformed_RDD(self): + basepath = self.tempdir.name + # non-batch-serialized RDD[[(K, V)]] should be rejected + data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] + rdd = self.sc.parallelize(data, numSlices=len(data)) + self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( + basepath + "/malformed/sequence")) class TestDaemon(unittest.TestCase): def connect(self, port): @@ -480,6 +809,57 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) +class TestWorker(PySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + def sleep(x): + import os, time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + self.sc.parallelize(range(1)).foreach(sleep) + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + data = open(path).read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + def test_fd_leak(self): + N = 1100 # fd limit is 1024 by default + rdd = self.sc.parallelize(range(N), N) + self.assertEquals(N, rdd.count()) + + class TestSparkSubmit(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() @@ -611,9 +991,26 @@ def test_serialize(self): self.assertEqual(expected, observed) +@unittest.skipIf(not _have_numpy, "NumPy not installed") +class NumPyTests(PySparkTestCase): + """General PySpark tests that depend on numpy """ + + def test_statcounter_array(self): + x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])]) + s = x.stats() + self.assertSequenceEqual([2.0,2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0,1.0], s.min().tolist()) + self.assertSequenceEqual([3.0,3.0], s.max().tolist()) + self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist()) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" + if not _have_numpy: + print "NOTE: Skipping NumPy tests as it does not seem to be installed" unittest.main() if not _have_scipy: print "NOTE: 
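TestWorker.test_cancel_task above relies on launching a blocking job from a daemon thread and cancelling it from the main thread. Condensed, with the PID-file bookkeeping left out and a crude sleep standing in for "wait until the task is running":

    import threading
    import time

    def run_job():
        # a job that would otherwise block for a long time;
        # cancellation surfaces as an exception inside this thread
        sc.parallelize(range(1)).foreach(lambda _: time.sleep(100))

    t = threading.Thread(target=run_job)
    t.daemon = True
    t.start()

    time.sleep(1)
    sc.cancelAllJobs()   # interrupts the running job
    t.join()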
SciPy tests were skipped as it does not seem to be installed" + if not _have_numpy: + print "NOTE: NumPy tests were skipped as it does not seem to be installed" diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 24d41b12d1b1a..2770f63059853 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -75,14 +75,19 @@ def main(infile, outfile): init_time = time.time() iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - except Exception as e: - # Write the error to stderr in addition to trying to pass it back to - # Java, in case it happened while serializing a record - print >> sys.stderr, "PySpark worker failed with exception:" - print >> sys.stderr, traceback.format_exc() - write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc(), outfile) - sys.exit(-1) + except Exception: + try: + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) + write_with_length(traceback.format_exc(), outfile) + outfile.flush() + except IOError: + # JVM close the socket + pass + except Exception: + # Write the error to stderr if it happened while serializing + print >> sys.stderr, "PySpark worker failed with exception:" + print >> sys.stderr, traceback.format_exc() + exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output diff --git a/python/run-tests b/python/run-tests index 29f755fc0dcd3..48feba2f5bd63 100755 --- a/python/run-tests +++ b/python/run-tests @@ -67,9 +67,11 @@ run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/linalg.py" +run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green diff --git a/repl/pom.xml b/repl/pom.xml index 4ebb1b82f0e8c..68f4504450778 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -55,6 +55,12 @@ ${project.version} runtime + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + org.eclipse.jetty jetty-server diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e1db4d5395ab9..65788f4646d91 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -230,6 +230,20 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, case xs => xs find (_.name == cmd) } } + private var fallbackMode = false + + private def toggleFallbackMode() { + val old = fallbackMode + fallbackMode = !old + System.setProperty("spark.repl.fallback", fallbackMode.toString) + echo(s""" + |Switched ${if (old) "off" else "on"} fallback mode without restarting. + | If you have defined classes in the repl, it would + |be good to redefine them incase you plan to use them. If you still run + |into issues it would be good to restart the repl and turn on `:fallback` + |mode as first command. 
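The worker.py hunk above changes error handling into a two-level fallback: first try to ship the traceback back over the existing stream, ignore an IOError if the JVM has already closed the socket, and only print to stderr if the reporting itself blows up. The generic shape of that pattern, with a made-up helper name:

    import sys
    import traceback

    def report_failure(outfile):
        try:
            outfile.write(traceback.format_exc())
            outfile.flush()
        except IOError:
            pass  # the other end already closed the connection; nothing to report to
        except Exception:
            # reporting itself failed (e.g. while serializing); last resort: stderr
            print >> sys.stderr, "worker failed with exception:"
            print >> sys.stderr, traceback.format_exc()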
+ """.stripMargin) + } /** Show the history */ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { @@ -299,6 +313,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), + nullary("fallback", """ + |disable/enable advanced repl changes, these fix some issues but may introduce others. + |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) ) @@ -557,29 +574,27 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, if (isReplPower) powerCommands else Nil )*/ - val replayQuestionMessage = + private val replayQuestionMessage = """|That entry seems to have slain the compiler. Shall I replay |your session? I can re-run each line except the last one. |[y/n] """.trim.stripMargin - private val crashRecovery: PartialFunction[Throwable, Boolean] = { - case ex: Throwable => - echo(intp.global.throwableAsString(ex)) - - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true + private def crashRecovery(ex: Throwable): Boolean = { + echo(ex.toString) + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("\nUnrecoverable error.") + throw ex + case _ => + def fn(): Boolean = + try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + true } /** The main read-eval-print loop for the repl. 
It calls @@ -605,7 +620,10 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } def innerLoop() { - if ( try processLine(readOneLine()) catch crashRecovery ) + val shouldContinue = try { + processLine(readOneLine()) + } catch {case t: Throwable => crashRecovery(t)} + if (shouldContinue) innerLoop() } innerLoop() @@ -951,9 +969,6 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, if (execUri != null) { conf.set("spark.executor.uri", execUri) } - if (System.getenv("SPARK_HOME") != null) { - conf.setSparkHome(System.getenv("SPARK_HOME")) - } sparkContext = new SparkContext(conf) logInfo("Created spark context..") sparkContext diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 3842c291d0b7b..84b57cd2dc1af 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -102,7 +102,8 @@ import org.apache.spark.util.Utils val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + val classServerPort = conf.getInt("spark.replClassServer.port", 0) + val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything @@ -892,11 +893,16 @@ import org.apache.spark.util.Utils def definedTypeSymbol(name: String) = definedSymbols(newTypeName(name)) def definedTermSymbol(name: String) = definedSymbols(newTermName(name)) + val definedClasses = handlers.exists { + case _: ClassHandler => true + case _ => false + } + /** Code to import bound names from previous lines - accessPath is code to * append to objectName to access anything bound by request. */ val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = - importsCode(referencedNames.toSet) + importsCode(referencedNames.toSet, definedClasses) /** Code to access a variable with the specified name */ def fullPath(vname: String) = { diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala index 9099e052f5796..193a42dcded12 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -108,8 +108,9 @@ trait SparkImports { * last one imported is actually usable. */ case class SparkComputedImports(prepend: String, append: String, access: String) + def fallback = System.getProperty("spark.repl.fallback", "false").toBoolean - protected def importsCode(wanted: Set[Name]): SparkComputedImports = { + protected def importsCode(wanted: Set[Name], definedClass: Boolean): SparkComputedImports = { /** Narrow down the list of requests from which imports * should be taken. Removes requests which cannot contribute * useful imports for the specified set of wanted names. @@ -124,8 +125,14 @@ trait SparkImports { // Single symbol imports might be implicits! See bug #1752. Rather than // try to finesse this, we will mimic all imports for now. 
def keepHandler(handler: MemberHandler) = handler match { - case _: ImportHandler => true - case x => x.definesImplicit || (x.definedNames exists wanted) + /* This case clause tries to "precisely" import only what is required. And in this + * it may miss out on some implicits, because implicits are not known in `wanted`. Thus + * it is suitable for defining classes. AFAIK while defining classes implicits are not + * needed.*/ + case h: ImportHandler if definedClass && !fallback => + h.importedNames.exists(x => wanted.contains(x)) + case _: ImportHandler => true + case x => x.definesImplicit || (x.definedNames exists wanted) } reqs match { @@ -182,7 +189,7 @@ trait SparkImports { // ambiguity errors will not be generated. Also, quote // the name of the variable, so that we don't need to // handle quoting keywords separately. - case x: ClassHandler => + case x: ClassHandler if !fallback => // I am trying to guess if the import is a defined class // This is an ugly hack, I am not 100% sure of the consequences. // Here we, let everything but "defined classes" use the import with val. diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index e2d8d5ff38dbe..c8763eb277052 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -256,6 +256,33 @@ class ReplSuite extends FunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + // We need to use local-cluster to test this case. + val output = runInterpreter("local-cluster[1,1,512]", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |import sqlContext.createSchemaRDD + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-2632 importing a method from non serializable class and not using it.") { + val output = runInterpreter("local", + """ + |class TestClass() { def testMethod = 3 } + |val t = new TestClass + |import t.testMethod + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 147b506dd5ca3..5c87da5815b64 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -36,4 +36,4 @@ export SPARK_HOME=${SPARK_PREFIX} export SPARK_CONF_DIR="$SPARK_HOME/conf" # Add the PySpark classes to the PYTHONPATH: export PYTHONPATH=$SPARK_HOME/python:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH +export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH diff --git a/sbin/spark-executor b/sbin/spark-executor index 336549f29c9ce..3621321a9bc8d 100755 --- a/sbin/spark-executor +++ b/sbin/spark-executor @@ -20,7 +20,7 @@ FWDIR="$(cd `dirname $0`/..; pwd)" export PYTHONPATH=$FWDIR/python:$PYTHONPATH -export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH +export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH echo "Running spark-executor with framework dir = $FWDIR" exec $FWDIR/bin/spark-class 
org.apache.spark.executor.MesosExecutorBackend diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh new file mode 100755 index 0000000000000..8398e6f19b511 --- /dev/null +++ b/sbin/start-thriftserver.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL Thrift server + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-thriftserver [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6decde3fcd62d..58d44e7923bee 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -32,24 +32,28 @@ Spark Project Catalyst http://spark.apache.org/ - catalyst + catalyst + + org.scala-lang + scala-compiler + org.scala-lang scala-reflect + + org.scalamacros + quasiquotes_${scala.binary.version} + ${scala.macros.version} + org.apache.spark spark-core_${scala.binary.version} ${project.version} - - com.typesafe - scalalogging-slf4j_${scala.binary.version} - 1.0.1 - org.scalatest scalatest_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5a55be1e51558..0d26b52a84695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -85,6 +85,26 @@ object ScalaReflection { case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } + def typeOfObject: PartialFunction[Any, DataType] = { + // The data type can be determined without ambiguity. + case obj: BooleanType.JvmType => BooleanType + case obj: BinaryType.JvmType => BinaryType + case obj: StringType.JvmType => StringType + case obj: ByteType.JvmType => ByteType + case obj: ShortType.JvmType => ShortType + case obj: IntegerType.JvmType => IntegerType + case obj: LongType.JvmType => LongType + case obj: FloatType.JvmType => FloatType + case obj: DoubleType.JvmType => DoubleType + case obj: DecimalType.JvmType => DecimalType + case obj: TimestampType.JvmType => TimestampType + case null => NullType + // For other cases, there is no obvious mapping from the type of the given object to a + // Catalyst data type. 
A user should provide his/her specific rules + // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of + // objects and then compose the user-defined PartialFunction with this one. + } + implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a34b236c8ac6a..2c73a80f64ebf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -210,21 +210,21 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } | "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } - protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { + protected lazy val joinedRelation: Parser[LogicalPlan] = + relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { case r1 ~ jt ~ _ ~ r2 ~ cond => Join(r1, r2, joinType = jt.getOrElse(Inner), cond) - } + } - protected lazy val joinConditions: Parser[Expression] = - ON ~> expression + protected lazy val joinConditions: Parser[Expression] = + ON ~> expression - protected lazy val joinType: Parser[JoinType] = - INNER ^^^ Inner | - LEFT ~ SEMI ^^^ LeftSemi | - LEFT ~ opt(OUTER) ^^^ LeftOuter | - RIGHT ~ opt(OUTER) ^^^ RightOuter | - FULL ~ opt(OUTER) ^^^ FullOuter + protected lazy val joinType: Parser[JoinType] = + INNER ^^^ Inner | + LEFT ~ SEMI ^^^ LeftSemi | + LEFT ~ opt(OUTER) ^^^ LeftOuter | + RIGHT ~ opt(OUTER) ^^^ RightOuter | + FULL ~ opt(OUTER) ^^^ FullOuter protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02bdb64f308a5..c18d7858f0a43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: + ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: StarExpansion :: @@ -109,17 +110,62 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - logger.trace(s"Attempting to resolve ${q.simpleString}") + logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolve(name).getOrElse(u) - logger.debug(s"Resolving $u to $result") + val result = q.resolveChildren(name).getOrElse(u) + logDebug(s"Resolving $u to $result") result } } } + /** + * In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT + * clause. This rule detects such queries and adds the required attributes to the original + * projection, so that they will be available during sorting. 
Another projection is added to + * remove these attributes after sorting. + */ + object ResolveSortReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => + val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + val resolved = unresolved.flatMap(child.resolveChildren) + val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet + + val missingInProject = requiredAttributes -- p.output + if (missingInProject.nonEmpty) { + // Add missing attributes and then project them away after the sort. + Project(projectList, + Sort(ordering, + Project(projectList ++ missingInProject, child))) + } else { + s // Nothing we can do here. Return original plan. + } + case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => + val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + // A small hack to create an object that will allow us to resolve any references that + // refer to named expressions that are present in the grouping expressions. + val groupingRelation = LocalRelation( + grouping.collect { case ne: NamedExpression => ne.toAttribute } + ) + + logDebug(s"Grouping expressions: $groupingRelation") + val resolved = unresolved.flatMap(groupingRelation.resolve).toSet + val missingInAggs = resolved -- a.outputSet + logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") + if (missingInAggs.nonEmpty) { + // Add missing grouping exprs and then project them away after the sort. + Project(a.output, + Sort(ordering, + Aggregate(grouping, aggs ++ missingInAggs, child))) + } else { + s // Nothing we can do here. Return original plan. + } + } + } + /** * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. */ @@ -159,7 +205,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if !filter.resolved && aggregate.resolved && containsAggregate(havingCondition) => { + if aggregate.resolved && containsAggregate(havingCondition) => { val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c0255701b7ba5..760c49fbca4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -18,17 +18,49 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Expression +import scala.collection.mutable /** A catalog for looking up user defined functions, used by an [[Analyzer]]. 
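The kind of query ResolveSortReferences above is meant to unblock, expressed through the Python SQL API (the table and columns are hypothetical, and sqlCtx is assumed to be a SQLContext with "people" registered): the sort key is absent from the SELECT list, so the rule projects it in for the sort and strips it again afterwards.

    sorted_names = sqlCtx.sql("SELECT name FROM people ORDER BY age")
    print sorted_names.collect()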
*/ trait FunctionRegistry { + type FunctionBuilder = Seq[Expression] => Expression + + def registerFunction(name: String, builder: FunctionBuilder): Unit + def lookupFunction(name: String, children: Seq[Expression]): Expression } +trait OverrideFunctionRegistry extends FunctionRegistry { + + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children)) + } +} + +class SimpleFunctionRegistry extends FunctionRegistry { + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders(name)(children) + } +} + /** * A trivial catalog that returns an error when a function is requested. Used for testing when all * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { + def registerFunction(name: String, builder: FunctionBuilder) = ??? + def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 47c7ad076ad07..15eb5982a4a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -49,10 +49,21 @@ trait HiveTypeCoercion { BooleanCasts :: StringToIntegralCasts :: FunctionArgumentConversion :: - CastNulls :: + CaseWhenCoercion :: Division :: Nil + trait TypeWidening { + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found return the widest common type, otherwise None + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -75,7 +86,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -133,16 +144,7 @@ trait HiveTypeCoercion { * - LongType to FloatType * - LongType to DoubleType */ - object WidenTypes extends Rule[LogicalPlan] { - - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - // Try and find a promotion rule that contains both types in question. 
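A plain-Python rendering of the SimpleFunctionRegistry idea above, just to make the FunctionBuilder contract concrete (a conceptual sketch, not a Spark API): each builder turns a list of argument expressions into a function expression.

    class SimpleFunctionRegistry(object):
        def __init__(self):
            self._builders = {}   # name -> builder(children) -> expression

        def register_function(self, name, builder):
            self._builders[name] = builder

        def lookup_function(self, name, children):
            return self._builders[name](children)   # KeyError if unregistered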
- val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) - } + object WidenTypes extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -154,7 +156,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +172,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + logDebug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +180,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + logDebug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right @@ -336,28 +338,34 @@ trait HiveTypeCoercion { } /** - * Ensures that NullType gets casted to some other types under certain circumstances. + * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CastNulls extends Rule[LogicalPlan] { + object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw @ CaseWhen(branches) => + case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => val valueTypes = branches.sliding(2, 2).map { - case Seq(_, value) if value.resolved => Some(value.dataType) - case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) - case _ => None + case Seq(_, value) => value.dataType + case Seq(elseVal) => elseVal.dataType }.toSeq - if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { - val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + + logDebug(s"Input values for null casting ${valueTypes.mkString(",")}") + + if (valueTypes.distinct.size > 1) { + val commonType = valueTypes.reduce { (v1, v2) => + findTightestCommonType(v1, v2) + .getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } val transformedBranches = branches.sliding(2, 2).map { - case Seq(cond, value) if value.resolved && value.dataType == NullType => - Seq(cond, Cast(value, otherType)) - case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => - Seq(Cast(elseVal, otherType)) + case Seq(cond, value) if value.dataType != commonType => + Seq(cond, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) case s => s }.reduce(_ ++ _) CaseWhen(transformedBranches) } else { - // It is possible to have more types due to the possibility of short-circuiting. + // Types match up. Hopefully some other rule fixes whatever is wrong with resolution. 
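An example of the branch-type mismatch CaseWhenCoercion above repairs, phrased as a HiveQL query with hypothetical table and column names, assuming a HiveContext is available: the THEN branch is integral, the ELSE branch is a double, and both are cast to their tightest common type during analysis.

    from pyspark.sql import HiveContext

    hiveCtx = HiveContext(sc)
    weights = hiveCtx.hql("SELECT CASE WHEN amount > 100 THEN 1 ELSE 0.5 END FROM orders")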
cw } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 7abeb032964e1..a0e25775da6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.{errors, trees} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.BaseRelation +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode /** @@ -36,7 +36,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str case class UnresolvedRelation( databaseName: Option[String], tableName: String, - alias: Option[String] = None) extends BaseRelation { + alias: Option[String] = None) extends LeafNode { override def output = Nil override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5c8c810d9135a..f44521d6381c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -202,7 +202,7 @@ package object dsl { // Protobuf terminology def required = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a) + def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 9ce1f01056462..0913f15888780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,72 +17,37 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.trees /** * A bound reference points to a specific slot in the input tuple, allowing the actual value * to be retrieved more efficiently. However, since operations like column pruning can change * the layout of intermediate tuples, BindReferences should be run after all such transformations. 
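A tiny, non-Catalyst illustration of the idea in the comment above: binding resolves a named attribute to its slot once, so per-row evaluation becomes a plain index lookup instead of a name lookup.

    child_output = ["id", "name", "age"]        # attributes produced by the child operator

    def bind_reference(attr_name, input_attrs):
        ordinal = input_attrs.index(attr_name)  # fails loudly if the attribute is missing
        return lambda row: row[ordinal]         # the "bound" reference

    get_age = bind_reference("age", child_output)
    assert get_age((7, "Jon", 42)) == 42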
*/ -case class BoundReference(ordinal: Int, baseReference: Attribute) - extends Attribute with trees.LeafNode[Expression] { +case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends Expression with trees.LeafNode[Expression] { type EvaluatedType = Any - override def nullable = baseReference.nullable - override def dataType = baseReference.dataType - override def exprId = baseReference.exprId - override def qualifiers = baseReference.qualifiers - override def name = baseReference.name - - override def newInstance = BoundReference(ordinal, baseReference.newInstance) - override def withNullability(newNullability: Boolean) = - BoundReference(ordinal, baseReference.withNullability(newNullability)) - override def withQualifiers(newQualifiers: Seq[String]) = - BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) + override def references = Set.empty - override def toString = s"$baseReference:$ordinal" + override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } -/** - * Used to denote operators that do their own binding of attributes internally. - */ -trait NoBind { self: trees.TreeNode[_] => } - -class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { - import BindReferences._ - - def apply(plan: TreeNode): TreeNode = { - plan.transform { - case n: NoBind => n.asInstanceOf[TreeNode] - case leafNode if leafNode.children.isEmpty => leafNode - case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => - bindReference(e, unaryNode.children.head.output) - } - } - } -} - object BindReferences extends Logging { def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - // TODO: This fallback is required because some operators (such as ScriptTransform) - // produce new attributes that can't be bound. Likely the right thing to do is remove - // this rule and require all operators to explicitly bind to the input schema that - // they specify. - logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") - a + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } else { - BoundReference(ordinal, a) + BoundReference(ordinal, a.dataType, a.nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 2c71d2c7b3563..8fc5896974438 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions + /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. + * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. 
*/ -class Projection(expressions: Seq[Expression]) extends (Row => Row) { +class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) @@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { } /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of th - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. - * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` - * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` - * and hold on to the returned [[Row]] before calling `next()`. + * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified + * expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ -case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { +case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) private[this] val exprArray = expressions.toArray - private[this] val mutableRow = new GenericMutableRow(exprArray.size) + private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) def currentValue: Row = mutableRow - def apply(input: Row): Row = { + override def target(row: MutableRow): MutableProjection = { + mutableRow = row + this + } + + override def apply(input: Row): Row = { var i = 0 while (i < exprArray.length) { mutableRow(i) = exprArray(i).eval(input) @@ -76,6 +77,12 @@ class JoinedRow extends Row { private[this] var row1: Row = _ private[this] var row2: Row = _ + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ def apply(r1: Row, r2: Row): Row = { row1 = r1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 74ae723686cfe..c9a63e201ef60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -32,6 +32,16 @@ object Row { * }}} */ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + + /** + * This method can be used to construct a [[Row]] with the given values. + */ + def apply(values: Any*): Row = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) } /** @@ -88,15 +98,6 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - - /** - * Experimental - * - * Returns a mutable string builder for the specified column. 
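The caveat spelled out in the removed MutableProjection scaladoc above still describes the behavior of the reused mutable row: a caller must copy a row before holding onto it, because the same buffer is overwritten for every input. A small Python sketch of why that matters:

    class ReusingIterator(object):
        """Yields the same list object for every record, like a mutable projection."""
        def __init__(self, rows):
            self.rows = rows
            self.buf = [None, None]

        def __iter__(self):
            for a, b in self.rows:
                self.buf[0], self.buf[1] = a, b   # overwrite the shared buffer
                yield self.buf

    wrong = [row for row in ReusingIterator([(1, 2), (3, 4)])]        # [[3, 4], [3, 4]]
    right = [list(row) for row in ReusingIterator([(1, 2), (3, 4)])]  # [[1, 2], [3, 4]]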
A given row should return the - * result of any mutations made to the returned buffer next time getString is called for the same - * column. - */ - def getStringBuilder(ordinal: Int): StringBuilder } /** @@ -180,6 +181,35 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + // Custom hashCode function that matches the efficient code generated version. + override def hashCode(): Int = { + var result: Int = 37 + + var i = 0 + while (i < values.length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + def copy() = this } @@ -187,8 +217,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - def getStringBuilder(ordinal: Int): StringBuilder = ??? - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 5e089f7618e0a..95633dd0c9870 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,13 +27,321 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi def references = children.flatMap(_.references).toSet def nullable = true + /** This method has been generated by this script + + (1 to 22).map { x => + val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) + val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + + s""" + case $x => + function.asInstanceOf[($anys) => Any]( + $evals) + """ + } + + */ + + // scalastyle:off override def eval(input: Row): Any = { children.size match { + case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => function.asInstanceOf[(Any, Any) => Any]( children(0).eval(input), children(1).eval(input)) + case 3 => + function.asInstanceOf[(Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input)) + case 4 => + function.asInstanceOf[(Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input)) + case 5 => + function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input)) + case 6 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input)) + case 7 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( + 
children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input)) + case 8 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input)) + case 9 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input)) + case 10 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input)) + case 11 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input)) + case 12 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input)) + case 13 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input)) + case 14 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input)) + case 15 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input)) + case 16 => + 
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input)) + case 17 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input)) + case 18 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input)) + case 19 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input)) + case 20 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input)) + case 21 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + 
children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input)) + case 22 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input), + children(21).eval(input)) } + // scalastyle:on } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index e787c59e75723..eb8898900d6a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -21,8 +21,16 @@ import scala.language.dynamics import org.apache.spark.sql.catalyst.types.DataType -case object DynamicType extends DataType +/** + * The data type representing [[DynamicRow]] values. + */ +case object DynamicType extends DataType { + def simpleString: String = "dynamic" +} +/** + * Wrap a [[Row]] as a [[DynamicRow]]. + */ case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow @@ -37,6 +45,11 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { } } +/** + * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language. + * Since the type of the column is not known at compile time, all attributes are converted to + * strings before being passed to the function. + */ class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) extends GenericRow(values) with Dynamic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala new file mode 100644 index 0000000000000..5b398695bf560 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import com.google.common.cache.{CacheLoader, CacheBuilder} + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + import scala.tools.reflect.ToolBox + + protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() + + protected val rowType = typeOf[Row] + protected val mutableRowType = typeOf[MutableRow] + protected val genericRowType = typeOf[GenericRow] + protected val genericMutableRowType = typeOf[GenericMutableRow] + + protected val projectionType = typeOf[Projection] + protected val mutableProjectionType = typeOf[MutableProjection] + + private val curId = new java.util.concurrent.atomic.AtomicInteger() + private val javaSeparator = "$" + + /** + * Generates a class for a given input expression. Called when there is not cached code + * already available. + */ + protected def create(in: InType): OutType + + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ + protected def canonicalize(in: InType): InType + + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + + /** + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint + */ + protected val cache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build( + new CacheLoader[InType, OutType]() { + override def load(in: InType): OutType = globalLock.synchronized { + create(in) + } + }) + + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = + apply(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) + + /** + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) + */ + protected def freshName(prefix: String): TermName = { + newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + } + + /** + * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * to null. 
+ * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not + * valid if `nullTerm` is set to `false`. + * @param objectTerm A possibly boxed version of the result of evaluating this expression. + */ + protected case class EvaluatedExpression( + code: Seq[Tree], + nullTerm: TermName, + primitiveTerm: TermName, + objectTerm: TermName) + + /** + * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that + * can be used to determine the result of evaluating the expression on an input row. + */ + def expressionEvaluator(e: Expression): EvaluatedExpression = { + val primitiveTerm = freshName("primitiveTerm") + val nullTerm = freshName("nullTerm") + val objectTerm = freshName("objectTerm") + + implicit class Evaluate1(e: Expression) { + def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${f(eval.primitiveTerm)} + """.children + } + } + + implicit class Evaluate2(expressions: (Expression, Expression)) { + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function from two primitive term names to a tree that evaluates them. + */ + def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = + evaluateAs(expressions._1.dataType)(f) + + def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (expressions._1.dataType != expressions._2.dataType) { + log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") + } + + val eval1 = expressionEvaluator(expressions._1) + val eval2 = expressionEvaluator(expressions._2) + val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + + eval1.code ++ eval2.code ++ + q""" + val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} + val $primitiveTerm: ${termForType(resultType)} = + if($nullTerm) { + ${defaultPrimitive(resultType)} + } else { + $resultCode.asInstanceOf[${termForType(resultType)}] + } + """.children : Seq[Tree] + } + } + + val inputTuple = newTermName(s"i") + + // TODO: Skip generation of null handling code when expression are not nullable. 
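
The quasiquoted fragments built by `Evaluate1` and `Evaluate2` above all share one null-propagation shape: evaluate the children, mark the result null if any input was null, otherwise combine the primitive values. A minimal interpreted analogue of that shape (plain Scala, using `Option` to stand in for the `nullTerm`/`primitiveTerm` pair; not Catalyst code) might look like:

```scala
// Interpreted analogue of the generated null-propagation pattern:
// the Boolean mirrors nullTerm, the second element mirrors primitiveTerm,
// and `default` stands in for defaultPrimitive(dataType).
def nullSafeBinary[A, R](left: Option[A], right: Option[A], default: R)(f: (A, A) => R): (Boolean, R) = {
  val isNull = left.isEmpty || right.isEmpty          // nullTerm = eval1.nullTerm || eval2.nullTerm
  val value  = if (isNull) default else f(left.get, right.get)
  (isNull, value)
}

// nullSafeBinary(Some(3), Some(4), -1)(_ + _) == (false, 7)
// nullSafeBinary(Some(3), None,    -1)(_ + _) == (true, -1)
```
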
+ val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { + case b @ BoundReference(ordinal, dataType, nullable) => + val nullValue = q"$inputTuple.isNullAt($ordinal)" + q""" + val $nullTerm: Boolean = $nullValue + val $primitiveTerm: ${termForType(dataType)} = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${getColumn(inputTuple, dataType, ordinal)} + """.children + + case expressions.Literal(null, dataType) => + q""" + val $nullTerm = true + val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] + """.children + + case expressions.Literal(value: Boolean, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: String, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Int, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Long, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case Cast(e @ BinaryType(), StringType) => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + """.children + + case Cast(child @ NumericType(), IntegerType) => + child.castOrNull(c => q"$c.toInt", IntegerType) + + case Cast(child @ NumericType(), LongType) => + child.castOrNull(c => q"$c.toLong", LongType) + + case Cast(child @ NumericType(), DoubleType) => + child.castOrNull(c => q"$c.toDouble", DoubleType) + + case Cast(child @ NumericType(), FloatType) => + child.castOrNull(c => q"$c.toFloat", IntegerType) + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case Cast(e, StringType) if e.dataType != TimestampType => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + ${eval.primitiveTerm}.toString + """.children + + case EqualTo(e1, e2) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } + + /* TODO: Fix null semantics. 
+ case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => + val eval = expressionEvaluator(e1) + + val checks = list.map { + case expressions.Literal(v: String, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + case expressions.Literal(v: Int, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + } + + val funcName = newTermName(s"isIn${curId.getAndIncrement()}") + + q""" + def $funcName: Boolean = { + ..${eval.code} + if(${eval.nullTerm}) return false + ..$checks + return false + } + val $nullTerm = false + val $primitiveTerm = $funcName + """.children + */ + + case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + case LessThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + + case And(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && !${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = false + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = true + } + """.children + + case Or(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && ${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = true + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = false + } + """.children + + case Not(child) => + // Uh, bad function name... 
+ child.castOrNull(c => q"!$c", BooleanType) + + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } + case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" } + + case IsNotNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} + """.children + + case IsNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} + """.children + + case c @ Coalesce(children) => + q""" + var $nullTerm = true + var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} + """.children ++ + children.map { c => + val eval = expressionEvaluator(c) + q""" + if($nullTerm) { + ..${eval.code} + if(!${eval.nullTerm}) { + $nullTerm = false + $primitiveTerm = ${eval.primitiveTerm} + } + } + """ + } + + case i @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition) + val trueEval = expressionEvaluator(trueValue) + val falseEval = expressionEvaluator(falseValue) + + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} + ..${condEval.code} + if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + ..${trueEval.code} + $nullTerm = ${trueEval.nullTerm} + $primitiveTerm = ${trueEval.primitiveTerm} + } else { + ..${falseEval.code} + $nullTerm = ${falseEval.nullTerm} + $primitiveTerm = ${falseEval.primitiveTerm} + } + """.children + } + + // If there was no match in the partial function above, we fall back on calling the interpreted + // expression evaluator. 
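
The fallback described in the comment above is the standard `PartialFunction.lift` idiom: the codegen rules are tried first, and any expression they do not cover is routed to the interpreted `eval` path. A small stand-alone illustration of that idiom (hypothetical names, not Catalyst types):

```scala
// Try a partial "fast path" first; fall back to a generic path when it is
// not defined for the input, mirroring primitiveEvaluation.lift(...).getOrElse(...).
val fastPath: PartialFunction[Int, String] = {
  case 0 => "constant-folded zero"
}

def interpreted(i: Int): String = s"interpreted($i)"

def evaluate(i: Int): String = fastPath.lift(i).getOrElse(interpreted(i))

// evaluate(0) == "constant-folded zero"
// evaluate(7) == "interpreted(7)"
```
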
+ val code: Seq[Tree] = + primitiveEvaluation.lift.apply(e).getOrElse { + log.debug(s"No rules to generate $e") + val tree = reify { e } + q""" + val $objectTerm = $tree.eval(i) + val $nullTerm = $objectTerm == null + val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] + """.children + } + + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + } + + protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + dataType match { + case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + } + } + + protected def setColumn( + destinationRow: TermName, + dataType: DataType, + ordinal: Int, + value: TermName) = { + dataType match { + case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => q"$destinationRow.update($ordinal, $value)" + } + } + + protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") + protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + + protected def primitiveForType(dt: DataType) = dt match { + case IntegerType => "Int" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case StringType => "String" + } + + protected def defaultPrimitive(dt: DataType) = dt match { + case BooleanType => ru.Literal(Constant(false)) + case FloatType => ru.Literal(Constant(-1.0.toFloat)) + case StringType => ru.Literal(Constant("")) + case ShortType => ru.Literal(Constant(-1.toShort)) + case LongType => ru.Literal(Constant(1L)) + case ByteType => ru.Literal(Constant(-1.toByte)) + case DoubleType => ru.Literal(Constant(-1.toDouble)) + case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. + case IntegerType => ru.Literal(Constant(-1)) + case _ => ru.Literal(Constant(null)) + } + + protected def termForType(dt: DataType) = dt match { + case n: NativeType => n.tag + case _ => typeTag[Any] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala new file mode 100644 index 0000000000000..a419fd7ecb39b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * input [[Row]] for a fixed set of [[Expression Expressions]]. + */ +object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + val mutableRowName = newTermName("mutableRow") + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => + val evaluationCode = expressionEvaluator(e) + + evaluationCode.code :+ + q""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i) + else + ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} + """ + } + + val code = + q""" + () => { new $mutableProjectionType { + + private[this] var $mutableRowName: $mutableRowType = + new $genericMutableRowType(${expressions.size}) + + def target(row: $mutableRowType): $mutableProjectionType = { + $mutableRowName = row + this + } + + /* Provide immutable access to the last projected row. */ + def currentValue: $rowType = mutableRow + + def apply(i: $rowType): $rowType = { + ..$projectionCode + mutableRow + } + } } + """ + + log.debug(s"code for ${expressions.mkString(",")}:\n$code") + toolBox.eval(code).asInstanceOf[() => MutableProjection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala new file mode 100644 index 0000000000000..094ff14552283 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NumericType} + +/** + * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of + * [[Expression Expressions]]. 
+ */ +object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) + + protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { + val a = newTermName("a") + val b = newTermName("b") + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = expressionEvaluator(order.child) + val evalB = expressionEvaluator(order.child) + + val compare = order.child.dataType match { + case _: NumericType => + q""" + val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} + if(comp != 0) { + return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} + } + """ + case StringType => + if (order.direction == Ascending) { + q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + } else { + q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + } + } + + q""" + i = $a + ..${evalA.code} + i = $b + ..${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) q"-1" else q"1"} + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) q"1" else q"-1"} + } else { + $compare + } + """ + } + + val q"class $orderingName extends $orderingType { ..$body }" = reify { + class SpecificOrdering extends Ordering[Row] { + val o = ordering + } + }.tree.children.head + + val code = q""" + class $orderingName extends $orderingType { + ..$body + def compare(a: $rowType, b: $rowType): Int = { + var i: $rowType = null // Holds current row being evaluated. + ..$comparisons + return 0 + } + } + new $orderingName() + """ + logDebug(s"Generated Ordering: $code") + toolBox.eval(code).asInstanceOf[Ordering[Row]] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala new file mode 100644 index 0000000000000..2a0935c790cf3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. 
+ */ +object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) + + protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = + BindReferences.bindReference(in, inputSchema) + + protected def create(predicate: Expression): ((Row) => Boolean) = { + val cEval = expressionEvaluator(predicate) + + val code = + q""" + (i: $rowType) => { + ..${cEval.code} + if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + } + """ + + log.debug(s"Generated predicate '$predicate':\n$code") + toolBox.eval(code).asInstanceOf[Row => Boolean] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala new file mode 100644 index 0000000000000..77fa02c13de30 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + + +/** + * Generates bytecode that produces a new [[Row]] object based on a fixed set of input + * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom + * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. + */ +object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + // Make Mutablility optional... + protected def create(expressions: Seq[Expression]): Projection = { + val tupleLength = ru.Literal(Constant(expressions.length)) + val lengthDef = q"final val length = $tupleLength" + + /* TODO: Configurable... 
+ val nullFunctions = + q""" + private final val nullSet = new org.apache.spark.util.collection.BitSet(length) + final def setNullAt(i: Int) = nullSet.set(i) + final def isNullAt(i: Int) = nullSet.get(i) + """ + */ + + val nullFunctions = + q""" + private[this] var nullBits = new Array[Boolean](${expressions.size}) + final def setNullAt(i: Int) = { nullBits(i) = true } + final def isNullAt(i: Int) = nullBits(i) + """.children + + val tupleElements = expressions.zipWithIndex.flatMap { + case (e, i) => + val elementName = newTermName(s"c$i") + val evaluatedExpression = expressionEvaluator(e) + val iLit = ru.Literal(Constant(i)) + + q""" + var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + { + ..${evaluatedExpression.code} + if(${evaluatedExpression.nullTerm}) + setNullAt($iLit) + else + $elementName = ${evaluatedExpression.primitiveTerm} + } + """.children : Seq[Tree] + } + + val iteratorFunction = { + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + q"final def iterator = Iterator[Any](..$allColumns)" + } + + val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" + val applyFunction = { + val cases = (0 until expressions.size).map { i => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" + } + q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }" + } + + val updateFunction = { + val cases = expressions.zipWithIndex.map {case (e, i) => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q""" + if(i == $ordinal) { + if(value == null) { + setNullAt(i) + } else { + $elementName = value.asInstanceOf[${termForType(e.dataType)}] + return + } + }""" + } + q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" + } + + val specificAccessorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) return $elementName" :: Nil + case _ => Nil + } + + q""" + final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } + + val specificMutatorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? 
+ q"if(i == $i) { $elementName = value; return }" :: Nil + case _ => Nil + } + + q""" + final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + ..$ifStatements; + $accessorFailure + }""" + } + + val hashValues = expressions.zipWithIndex.map { case (e,i) => + val elementName = newTermName(s"c$i") + val nonNull = e.dataType match { + case BooleanType => q"if ($elementName) 0 else 1" + case ByteType | ShortType | IntegerType => q"$elementName.toInt" + case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" + case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case DoubleType => + q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" + case _ => q"$elementName.hashCode" + } + q"if (isNullAt($i)) 0 else $nonNull" + } + + val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + + val hashCodeFunction = + q""" + override def hashCode(): Int = { + var result: Int = 37 + ..$hashUpdates + result + } + """ + + val columnChecks = (0 until expressions.size).map { i => + val elementName = newTermName(s"c$i") + q"if (this.$elementName != specificType.$elementName) return false" + } + + val equalsFunction = + q""" + override def equals(other: Any): Boolean = other match { + case specificType: SpecificRow => + ..$columnChecks + return true + case other => super.equals(other) + } + """ + + val copyFunction = + q""" + final def copy() = new $genericRowType(this.toArray) + """ + + val classBody = + nullFunctions ++ ( + lengthDef +: + iteratorFunction +: + applyFunction +: + updateFunction +: + equalsFunction +: + hashCodeFunction +: + copyFunction +: + (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) + + val code = q""" + final class SpecificRow(i: $rowType) extends $mutableRowType { + ..$classBody + } + + new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + """ + + log.debug( + s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") + toolBox.eval(code).asInstanceOf[Projection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala new file mode 100644 index 0000000000000..80c7dfd376c96 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.rules +import org.apache.spark.sql.catalyst.util + +/** + * A collection of generators that build custom bytecode at runtime for performing the evaluation + * of catalyst expression. + */ +package object codegen { + + /** + * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala + * 2.10. + */ + protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock + + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ + object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { + val batches = + Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil + + object CleanExpressions extends rules.Rule[Expression] { + def apply(e: Expression): Expression = e transform { + case Alias(c, _) => c + } + } + } + + /** + * :: DeveloperApi :: + * Dumps the bytecode from a class to the screen using javap. + */ + @DeveloperApi + object DumpByteCode { + import scala.sys.process._ + val dumpDirectory = util.getTempFilePath("sparkSqlByteCode") + dumpDirectory.mkdir() + + def apply(obj: Any): Unit = { + val generatedClass = obj.getClass + val classLoader = + generatedClass + .getClassLoader + .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] + val generatedBytes = classLoader.classBytes(generatedClass.getName) + + val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName) + if (!packageDir.exists()) { packageDir.mkdir() } + + val classFile = + new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class") + + val outfile = new java.io.FileOutputStream(classFile) + outfile.write(generatedBytes) + outfile.close() + + println( + s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!) 
+ } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 5d3bb25ad568c..c1154eb81c319 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.types._ /** @@ -31,8 +33,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { override def foldable = child.foldable && ordinal.foldable override def references = children.flatMap(_.references).toSet def dataType = child.dataType match { - case ArrayType(dt) => dt - case MapType(_, vt) => vt + case ArrayType(dt, _) => dt + case MapType(_, vt, _) => vt } override lazy val resolved = childrenResolved && @@ -61,7 +63,6 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } } else { val baseValue = value.asInstanceOf[Map[Any, _]] - val key = ordinal.eval(input) baseValue.get(key).orNull } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index dd78614754e12..3d41acb79e5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.types._ @@ -84,8 +86,8 @@ case class Explode(attributeNames: Seq[String], child: Expression) (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et) => et :: Nil - case MapType(kt,vt) => kt :: vt :: Nil + case ArrayType(et, _) => et :: Nil + case MapType(kt,vt, _) => kt :: vt :: Nil } // TODO: Move this pattern into Generator. @@ -102,10 +104,10 @@ case class Explode(attributeNames: Seq[String], child: Expression) override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { - case ArrayType(_) => + case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) - case MapType(_, _) => + case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 934bad8c27294..02d04762629f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -28,8 +28,8 @@ object NamedExpression { } /** - * A globally (within this JVM) id for a given named expression. - * Used to identify with attribute output by a relation is being + * A globally unique (within this JVM) id for a given named expression. 
+ * Used to identify which attribute output by a relation is being * referenced in a subsequent computation. */ case class ExprId(id: Long) @@ -134,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea /** * Returns a copy of this [[AttributeReference]] with changed nullability. */ - override def withNullability(newNullability: Boolean) = { + override def withNullability(newNullability: Boolean): AttributeReference = { if (nullable == newNullability) { this } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index b6f2451b52e1f..55d95991c5f11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -47,4 +47,30 @@ package org.apache.spark.sql.catalyst * ==Evaluation== * The result of expressions can be evaluated using the `Expression.apply(Row)` method. */ -package object expressions +package object expressions { + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + */ + abstract class Projection extends (Row => Row) + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` + * and hold on to the returned [[Row]] before calling `next()`. + */ + abstract class MutableProjection extends Projection { + def currentValue: Row + + /** Uses the given row to store the output of the projection. */ + def target(row: MutableRow): MutableProjection + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 06b94a98d3cd0..5976b0ddf3e03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType object InterpretedPredicate { + def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + apply(BindReferences.bindReference(expression, inputSchema)) + def apply(expression: Expression): (Row => Boolean) = { (r: Row) => expression.eval(r).asInstanceOf[Boolean] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala new file mode 100644 index 0000000000000..bdd07bbeb2230 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object catalyst { + /** + * A JVM-global lock that should be used to prevent thread safety issues when using things in + * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for + * 2.10.* builds. See SI-6240 for more details. + */ + protected[catalyst] object ScalaReflectionLock + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 67833664b35ae..5839c9f7c43ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 026692abe067d..90923fe31a063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import org.apache.spark.sql.Logging - import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -104,6 +103,77 @@ object PhysicalOperation extends PredicateHelper { } } +/** + * Matches a logical aggregation that can be performed on distributed data in two steps. The first + * operates on the data in each partition performing partial aggregation for each group. The second + * occurs after the shuffle and completes the aggregation. + * + * This pattern will only match if all aggregate expressions can be computed partially and will + * return the rewritten aggregation expressions for both phases. + * + * The returned values for this match are as follows: + * - Grouping attributes for the final aggregation. + * - Aggregates for the final aggregation. + * - Grouping expressions for the partial aggregation. + * - Partial aggregate expressions. + * - Input to the aggregation. 
+ */ +object PartialAggregation { + type ReturnType = + (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions. + val allAggregates = + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + // Collect all aggregate expressions that can be computed partially. + val partialAggregates = + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + + // Only do partial aggregation if supported by all aggregate expressions. + if (allAggregates.size == partialAggregates.size) { + // Create a map of expressions to their partial evaluations for all aggregate expressions. + val partialEvaluations: Map[Long, SplitEvaluation] = + partialAggregates.map(a => (a.id, a.asPartial)).toMap + + // We need to pass all grouping expressions though so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if partialEvaluations.contains(e.id) => + partialEvaluations(e.id).finalEvaluation + case e: Expression if namedGroupingExpressions.contains(e) => + namedGroupingExpressions(e).toAttribute + }).asInstanceOf[Seq[NamedExpression]] + + val partialComputation = + (namedGroupingExpressions.values ++ + partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + + val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + + Some( + (namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child)) + } else { + None + } + case _ => None + } +} + + /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ @@ -114,7 +184,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logger.debug(s"Considering join on: $condition") + logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. 
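
`PartialAggregation` above only rewrites the logical plan; the intuition behind the two-step split can be shown with a hand-rolled two-phase average, where per-partition partial results (a sum and a count) are merged after the shuffle. This is illustrative code only, not Catalyst's aggregate classes:

```scala
// Phase 1 runs inside each partition; phase 2 runs once after the shuffle.
case class PartialAvg(sum: Double, count: Long)

def partialPhase(partition: Seq[Double]): PartialAvg =
  PartialAvg(partition.sum, partition.size.toLong)

def finalPhase(partials: Seq[PartialAvg]): Double = {
  val totalSum   = partials.map(_.sum).sum
  val totalCount = partials.map(_.count).sum
  totalSum / totalCount
}

// finalPhase(Seq(partialPhase(Seq(1.0, 2.0)), partialPhase(Seq(3.0)))) == 2.0
```
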
val (joinPredicates, otherPredicates) = @@ -132,7 +202,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7b82e19b2e714..0988b0c6d990c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -125,51 +125,10 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - protected def generateSchemaString(schema: Seq[Attribute]): String = { - val builder = new StringBuilder - builder.append("root\n") - val prefix = " |" - schema.foreach { attribute => - val name = attribute.name - val dataType = attribute.dataType - dataType match { - case fields: StructType => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(fields: StructType) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(elementType: DataType) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case _ => builder.append(s"$prefix-- $name: $dataType\n") - } - } - - builder.toString() - } - - protected def generateSchemaString( - schema: StructType, - prefix: String, - builder: StringBuilder): StringBuilder = { - schema.fields.foreach { - case StructField(name, fields: StructType, _) => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(fields: StructType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(elementType: DataType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case StructField(name, fieldType: DataType, _) => - builder.append(s"$prefix-- $name: $fieldType\n") - } - - builder - } + def schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ - def schemaString: String = generateSchemaString(output) + def schemaString: String = schema.treeString /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index edc37e3877c0e..278569f0cb14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] { self: Product => + /** + * Estimates of various statistics. The default estimation logic simply lazily multiplies the + * corresponding statistic produced by the children. 
To override this behavior, override + * `statistics` and assign it an overriden version of `Statistics`. + * + * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the + * performance of the implementations. The reason is that estimations might get triggered in + * performance-critical processes, such as query plan planning. + * + * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it + * defaults to the product of children's `sizeInBytes`. + */ + case class Statistics( + sizeInBytes: BigInt + ) + lazy val statistics: Statistics = Statistics( + sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product + ) + /** * Returns the set of attributes that are referenced by this node * during evaluation. @@ -53,16 +72,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as + * Optionally resolves the given string to a [[NamedExpression]] using the input from all child + * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String): Option[NamedExpression] = { + def resolveChildren(name: String): Option[NamedExpression] = + resolve(name, children.flatMap(_.output)) + + /** + * Optionally resolves the given string to a [[NamedExpression]] based on the output of this + * LogicalPlan. The attribute is expressed as string in the following form: + * `[scope].AttributeName.[nested].[fields]...`. + */ + def resolve(name: String): Option[NamedExpression] = + resolve(name, output) + + /** Performs attribute resolution given a name and a sequence of possible attributes. */ + protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { val parts = name.split("\\.") // Collect all attributes that are output by this nodes children where either the first part // matches the name or where the first part matches the scope and the second part matches the // name. Return these matches along with any remaining parts, which represent dotted access to // struct fields. - val options = children.flatMap(_.output).flatMap { option => + val options = input.flatMap { option => // If the first part of the desired name matches a qualifier for this possible match, drop it. val remainingParts = if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts @@ -70,15 +102,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { } options.distinct match { - case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it. + case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. // One match, but we also need to extract the requested nested field. - case (a, nestedFields) :: Nil => + case Seq((a, nestedFields)) => a.dataType match { case StructType(fields) => Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) case _ => None // Don't know how to resolve these field references } - case Nil => None // No matches. + case Seq() => None // No matches. 
case ambiguousReferences => throw new TreeNodeException( this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") @@ -92,6 +124,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => + override lazy val statistics: Statistics = + throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") + // Leaf nodes by definition cannot reference any input attributes. override def references = Set.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 1537de259c5b4..3cb407217c4c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -177,7 +177,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case StructType(fields) => StructType(fields.map(f => StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable))) - case ArrayType(elemType) => ArrayType(lowerCaseSchema(elemType)) + case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull) case otherType => otherType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 1d5f033f0d274..481a5a4f212b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -35,7 +35,7 @@ abstract class Command extends LeafNode { */ case class NativeCommand(cmd: String) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) + Seq(AttributeReference("result", StringType, nullable = false)()) } /** @@ -43,8 +43,7 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(0, AttributeReference("key", StringType, nullable = false)()), - BoundReference(1, AttributeReference("value", StringType, nullable = false)())) + AttributeReference("", StringType, nullable = false)()) } /** @@ -53,7 +52,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman */ case class ExplainCommand(plan: LogicalPlan) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) + Seq(AttributeReference("plan", StringType, nullable = false)()) } /** @@ -72,7 +71,7 @@ case class DescribeCommand( isExtended: Boolean) extends Command { override def output = Seq( // Column names are based on Hive. 
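To make the statistics contract above concrete, here is a minimal standalone sketch of the same scheme, not the Catalyst `LogicalPlan`/`LeafNode` classes: inner nodes lazily multiply their children's `sizeInBytes`, while leaves must supply their own estimate, mirroring the `UnsupportedOperationException` default added to `LeafNode`. All names here are illustrative.

```scala
// Standalone sketch of the estimation scheme described above (illustrative names only).
object StatisticsSketch {
  final case class Statistics(sizeInBytes: BigInt)

  trait PlanNode {
    def children: Seq[PlanNode]
    // Lazily computed so the (possibly expensive) estimate is only paid for when planning asks.
    lazy val statistics: Statistics =
      Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
  }

  // A leaf describing, say, a scanned file: it knows its size directly.
  case class Scan(bytes: BigInt) extends PlanNode {
    override def children: Seq[PlanNode] = Nil
    override lazy val statistics: Statistics = Statistics(bytes)
  }

  case class Join(left: PlanNode, right: PlanNode) extends PlanNode {
    override def children: Seq[PlanNode] = Seq(left, right)
  }

  def main(args: Array[String]): Unit = {
    val plan = Join(Scan(BigInt(1000)), Scan(BigInt(20)))
    // A planner could compare this against a broadcast threshold to pick a join strategy.
    println(plan.statistics.sizeInBytes) // 20000
  }
}
```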
- BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), - BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), - BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) + AttributeReference("col_name", StringType, nullable = false)(), + AttributeReference("data_type", StringType, nullable = false)(), + AttributeReference("comment", StringType, nullable = false)()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index 1076537bc7602..03414b2301e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e32adb76fe146..d192b151ac1c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql -package catalyst -package rules +package org.apache.spark.sql.catalyst.rules +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -61,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - logger.trace( + logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -72,25 +71,28 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } iteration += 1 if (iteration > batch.strategy.maxIterations) { - logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}") + // Only log if this is a rule that is supposed to run more than once. 
+ if (iteration != 2) { + logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + } continue = false } if (curPlan.fastEquals(lastPlan)) { - logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + logTrace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - logger.debug( + logDebug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - logger.trace(s"Batch ${batch.name} has no effect.") + logTrace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index d159ecdd5d781..d725a92c06f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.Logger +import org.apache.spark.Logging /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are @@ -33,7 +33,8 @@ import org.apache.spark.sql.Logger *
 *   <li>debugging support - pretty printing, easy splicing of trees, etc.</li>
  • * */ -package object trees { +package object trees extends Logging { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. - protected val logger = Logger("catalyst.trees") + protected override def logName = "catalyst.trees" + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cd4b5e9c1b529..b52ee6d3378a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -23,16 +23,13 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers +import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils /** - * A JVM-global lock that should be used to prevent thread safety issues when using things in - * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for - * 2.10.* builds. See SI-6240 for more details. + * Utility functions for working with DataTypes. */ -protected[catalyst] object ScalaReflectionLock - object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = "StringType" ^^^ StringType | @@ -48,11 +45,13 @@ object DataType extends RegexParsers { "TimestampType" ^^^ TimestampType protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) } protected lazy val structField: Parser[StructField] = @@ -85,6 +84,21 @@ object DataType extends RegexParsers { case Success(result, _) => result case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } } abstract class DataType { @@ -95,49 +109,65 @@ abstract class DataType { } def isPrimitive: Boolean = false + + def simpleString: String +} + +case object NullType extends DataType { + def simpleString: String = "null" } -case object NullType extends DataType +object NativeType { + def all = Seq( + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + def unapply(dt: DataType): Boolean = all.contains(dt) +} trait PrimitiveType extends DataType { override def isPrimitive = true } abstract class NativeType extends DataType { - type JvmType - @transient val tag: TypeTag[JvmType] - val ordering: Ordering[JvmType] + private[sql] type JvmType + @transient private[sql] val tag: TypeTag[JvmType] + private[sql] val ordering: Ordering[JvmType] - @transient val classTag = 
ScalaReflectionLock.synchronized { + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { val mirror = runtimeMirror(Utils.getSparkClassLoader) ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) } } case object StringType extends NativeType with PrimitiveType { - type JvmType = String - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "string" } case object BinaryType extends DataType with PrimitiveType { - type JvmType = Array[Byte] + private[sql] type JvmType = Array[Byte] + def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { - type JvmType = Boolean - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "boolean" } case object TimestampType extends NativeType { - type JvmType = Timestamp + private[sql] type JvmType = Timestamp - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = new Ordering[JvmType] { + private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } + + def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -146,7 +176,11 @@ abstract class NumericType extends NativeType with PrimitiveType { // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. 
- val numeric: Numeric[JvmType] + private[sql] val numeric: Numeric[JvmType] +} + +object NumericType { + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } /** Matcher for any expressions that evaluate to [[IntegralType]]s */ @@ -158,39 +192,43 @@ object IntegralType { } abstract class IntegralType extends NumericType { - val integral: Integral[JvmType] + private[sql] val integral: Integral[JvmType] } case object LongType extends IntegralType { - type JvmType = Long - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Long]] - val integral = implicitly[Integral[Long]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Long]] + private[sql] val integral = implicitly[Integral[Long]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "long" } case object IntegerType extends IntegralType { - type JvmType = Int - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Int]] - val integral = implicitly[Integral[Int]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Int]] + private[sql] val integral = implicitly[Integral[Int]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "integer" } case object ShortType extends IntegralType { - type JvmType = Short - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Short]] - val integral = implicitly[Integral[Short]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Short]] + private[sql] val integral = implicitly[Integral[Short]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "short" } case object ByteType extends IntegralType { - type JvmType = Byte - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Byte]] - val integral = implicitly[Integral[Byte]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Byte]] + private[sql] val integral = implicitly[Integral[Byte]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -201,47 +239,159 @@ object FractionalType { } } abstract class FractionalType extends NumericType { - val fractional: Fractional[JvmType] + private[sql] val fractional: Fractional[JvmType] } case object DecimalType extends FractionalType { - type JvmType = BigDecimal - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[BigDecimal]] - val fractional = implicitly[Fractional[BigDecimal]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = BigDecimal + @transient 
private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[BigDecimal]] + private[sql] val fractional = implicitly[Fractional[BigDecimal]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "decimal" } case object DoubleType extends FractionalType { - type JvmType = Double - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Double]] - val fractional = implicitly[Fractional[Double]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "double" } case object FloatType extends FractionalType { - type JvmType = Float - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Float]] - val fractional = implicitly[Fractional[Float]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "float" +} + +object ArrayType { + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, false) +} + +/** + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * @param elementType The data type of values. + * @param containsNull Indicates if values have `null` values + */ +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append( + s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + DataType.buildFormattedString(elementType, s"$prefix |", builder) + } + + def simpleString: String = "array" } -case class ArrayType(elementType: DataType) extends DataType +/** + * A field inside a StructType. + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. 
+ */ +case class StructField(name: String, dataType: DataType, nullable: Boolean) { -case class StructField(name: String, dataType: DataType, nullable: Boolean) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + DataType.buildFormattedString(dataType, s"$prefix |", builder) + } +} object StructType { - def fromAttributes(attributes: Seq[Attribute]): StructType = { + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - } - // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq) + private def validateFields(fields: Seq[StructField]): Boolean = + fields.map(field => field.name).distinct.size == fields.size } case class StructType(fields: Seq[StructField]) extends DataType { - def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + require(StructType.validateFields(fields), "Found fields with the same name.") + + /** + * Returns all field names in a [[Seq]]. + */ + lazy val fieldNames: Seq[String] = fields.map(_.name) + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a name matching the given name, `null` will be returned. + */ + def apply(name: String): StructField = { + nameToField.get(name).getOrElse( + throw new IllegalArgumentException(s"Field ${name} does not exist.")) + } + + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names. + * Those names which do not have matching fields will be ignored. + */ + def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (!nonExistFields.isEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } + // Preserve the original order of fields. + StructType(fields.filter(f => names.contains(f.name))) + } + + protected[sql] def toAttributes = + fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + + def treeString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printTreeString(): Unit = println(treeString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + def simpleString: String = "struct" } -case class MapType(keyType: DataType, valueType: DataType) extends DataType +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, true) +} + +/** + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. 
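A short usage sketch of the Scala-side types defined above, assuming the catalyst module is on the classpath. The constructors and methods used (`ArrayType(elementType)`, `MapType(keyType, valueType)`, `StructType.apply(name)`, `apply(names: Set[String])`, `printTreeString()`) are the ones introduced in this diff; the surrounding object is illustrative only.

```scala
import org.apache.spark.sql.catalyst.types._

object SchemaSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = false),
      // ArrayType(elementType) defaults containsNull to false via the companion object above.
      StructField("scores", ArrayType(IntegerType), nullable = true),
      // MapType(keyType, valueType) defaults valueContainsNull to true.
      StructField("properties", MapType(StringType, StringType), nullable = true)))

    // Look up a single field by name (throws IllegalArgumentException if it is missing).
    val nameField: StructField = schema("name")

    // Project a subset of fields, preserving their original order.
    val projected: StructType = schema(Set("name", "scores"))

    // Print the schema in the indented tree format backing schemaString/printSchema.
    schema.printTreeString()
  }
}
```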
+ */ +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") + builder.append(s"${prefix}-- value: ${valueType.simpleString} " + + s"(valueContainsNull = ${valueContainsNull})\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + def simpleString: String = "map" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index c0438dbe52a47..e75373d5a74a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst +import java.math.BigInteger import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ case class PrimitiveData( @@ -148,4 +148,68 @@ class ScalaReflectionSuite extends FunSuite { StructField("_2", StringType, nullable = true))), nullable = true)) } + + test("get data type of a value") { + // BooleanType + assert(BooleanType === typeOfObject(true)) + assert(BooleanType === typeOfObject(false)) + + // BinaryType + assert(BinaryType === typeOfObject("string".getBytes)) + + // StringType + assert(StringType === typeOfObject("string")) + + // ByteType + assert(ByteType === typeOfObject(127.toByte)) + + // ShortType + assert(ShortType === typeOfObject(32767.toShort)) + + // IntegerType + assert(IntegerType === typeOfObject(2147483647)) + + // LongType + assert(LongType === typeOfObject(9223372036854775807L)) + + // FloatType + assert(FloatType === typeOfObject(3.4028235E38.toFloat)) + + // DoubleType + assert(DoubleType === typeOfObject(1.7976931348623157E308)) + + // DecimalType + assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) + + // TimestampType + assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00"))) + + // NullType + assert(NullType === typeOfObject(null)) + + def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + case value: java.math.BigDecimal => DecimalType + case _ => StringType + } + + assert(DecimalType === typeOfObject1( + new BigInteger("92233720368547758070"))) + assert(DecimalType === typeOfObject1( + new java.math.BigDecimal("1.7976931348623157E318"))) + assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) + + def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + } + + intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) + + def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { + case c: Seq[_] => ArrayType(typeOfObject3(c.head)) + } + + assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) + assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 
58f8c341e6676..999c9fff38d60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class ExpressionEvaluationSuite extends FunSuite { test("literals") { - assert((Literal(1) + Literal(1)).eval(null) === 2) + checkEvaluation(Literal(1), 1) + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(0L), 0L) + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal(1) + Literal(1), 2) } /** @@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - val expr = ! Literal(v, BooleanType) - val result = expr.eval(null) - if (result != answer) - fail(s"$expr should not evaluate to $result, expected: $answer") } + checkEvaluation(!Literal(v, BooleanType), answer) + } } booleanLogicTest("AND", _ && _, @@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal(null, StringType).like("a"), null) checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) @@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23) + checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23) - checkEvaluation("23" cast DecimalType, 23) - checkEvaluation("23" cast ByteType, 23) - checkEvaluation("23" cast ShortType, 23) + checkEvaluation("23" cast FloatType, 23f) + checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast ByteType, 23.toByte) + checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) checkEvaluation(Literal(123) cast IntegerType, 123) - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24) + checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24) + checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) + checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) + checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} @@ -391,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite { val typeMap = MapType(StringType, StringType) val typeArray = ArrayType(StringType) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + 
checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(null, IntegerType)), null, row) - checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row) + checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) val typeS_notNullable = StructType( @@ -413,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS)()), "a").nullable === true) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false) + assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala new file mode 100644 index 0000000000000..245a2e148030c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use code generation for evaluation. 
+ */ +class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val plan = try { + GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val evaluated = GenerateProjection.expressionEvaluator(expression) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + + test("multithreaded eval") { + import scala.concurrent._ + import ExecutionContext.Implicits.global + import scala.concurrent.duration._ + + val futures = (1 to 20).map { _ => + future { + GeneratePredicate(EqualTo(Literal(1), Literal(1))) + GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + } + } + + futures.foreach(Await.result(_, 10.seconds)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala new file mode 100644 index 0000000000000..887aabb1d5fb4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use generated code on mutable rows. 
+ */ +class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + + val plan = try { + GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](expected)) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 4896f1b955f01..e2ae0d25db1a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Combine Limit", FixedPoint(2), + Batch("Combine Limit", FixedPoint(10), CombineLimits) :: - Batch("Constant Folding", FixedPoint(3), + Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, BooleanSimplification) :: Nil diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c309c43804d97..c8016e41256d5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -32,7 +32,7 @@ Spark Project SQL http://spark.apache.org/ - sql + sql @@ -68,6 +68,11 @@ jackson-databind 2.3.0
    + + junit + junit + test + org.scalatest scalatest_${scala.binary.version} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java new file mode 100644 index 0000000000000..b73a371e93001 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing Lists. + * An ArrayType object comprises two fields, {@code DataType elementType} and + * {@code boolean containsNull}. The field of {@code elementType} is used to specify the type of + * array elements. The field of {@code containsNull} is used to specify if the array has + * {@code null} values. + * + * To create an {@link ArrayType}, + * {@link DataType#createArrayType(DataType)} or + * {@link DataType#createArrayType(DataType, boolean)} + * should be used. + */ +public class ArrayType extends DataType { + private DataType elementType; + private boolean containsNull; + + protected ArrayType(DataType elementType, boolean containsNull) { + this.elementType = elementType; + this.containsNull = containsNull; + } + + public DataType getElementType() { + return elementType; + } + + public boolean isContainsNull() { + return containsNull; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ArrayType arrayType = (ArrayType) o; + + if (containsNull != arrayType.containsNull) return false; + if (!elementType.equals(arrayType.elementType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = elementType.hashCode(); + result = 31 * result + (containsNull ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java new file mode 100644 index 0000000000000..7daad60f62a0b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing byte[] values. + * + * {@code BinaryType} is represented by the singleton object {@link DataType#BinaryType}. + */ +public class BinaryType extends DataType { + protected BinaryType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java new file mode 100644 index 0000000000000..5a1f52725631b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing boolean and Boolean values. + * + * {@code BooleanType} is represented by the singleton object {@link DataType#BooleanType}. + */ +public class BooleanType extends DataType { + protected BooleanType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java new file mode 100644 index 0000000000000..e5cdf06b21bbe --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing byte and Byte values. + * + * {@code ByteType} is represented by the singleton object {@link DataType#ByteType}. 
+ */ +public class ByteType extends DataType { + protected ByteType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java new file mode 100644 index 0000000000000..3eccddef88134 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The base type of all Spark SQL data types. + * + * To get/create specific data type, users should use singleton objects and factory methods + * provided by this class. + */ +public abstract class DataType { + + /** + * Gets the StringType object. + */ + public static final StringType StringType = new StringType(); + + /** + * Gets the BinaryType object. + */ + public static final BinaryType BinaryType = new BinaryType(); + + /** + * Gets the BooleanType object. + */ + public static final BooleanType BooleanType = new BooleanType(); + + /** + * Gets the TimestampType object. + */ + public static final TimestampType TimestampType = new TimestampType(); + + /** + * Gets the DecimalType object. + */ + public static final DecimalType DecimalType = new DecimalType(); + + /** + * Gets the DoubleType object. + */ + public static final DoubleType DoubleType = new DoubleType(); + + /** + * Gets the FloatType object. + */ + public static final FloatType FloatType = new FloatType(); + + /** + * Gets the ByteType object. + */ + public static final ByteType ByteType = new ByteType(); + + /** + * Gets the IntegerType object. + */ + public static final IntegerType IntegerType = new IntegerType(); + + /** + * Gets the LongType object. + */ + public static final LongType LongType = new LongType(); + + /** + * Gets the ShortType object. + */ + public static final ShortType ShortType = new ShortType(); + + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}). + * The field of {@code containsNull} is set to {@code false}. + */ + public static ArrayType createArrayType(DataType elementType) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, false); + } + + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and + * whether the array contains null values ({@code containsNull}). 
+ */ + public static ArrayType createArrayType(DataType elementType, boolean containsNull) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, containsNull); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}) and values + * ({@code keyType}). The field of {@code valueContainsNull} is set to {@code true}. + */ + public static MapType createMapType(DataType keyType, DataType valueType) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType, true); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of + * values ({@code keyType}), and whether values contain any null value + * ({@code valueContainsNull}). + */ + public static MapType createMapType( + DataType keyType, + DataType valueType, + boolean valueContainsNull) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType, valueContainsNull); + } + + /** + * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and + * whether values of this field can be null values ({@code nullable}). + */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + if (name == null) { + throw new IllegalArgumentException("name should not be null."); + } + if (dataType == null) { + throw new IllegalArgumentException("dataType should not be null."); + } + + return new StructField(name, dataType, nullable); + } + + /** + * Creates a StructType with the given list of StructFields ({@code fields}). + */ + public static StructType createStructType(List fields) { + return createStructType(fields.toArray(new StructField[0])); + } + + /** + * Creates a StructType with the given StructField array ({@code fields}). + */ + public static StructType createStructType(StructField[] fields) { + if (fields == null) { + throw new IllegalArgumentException("fields should not be null."); + } + Set distinctNames = new HashSet(); + for (StructField field: fields) { + if (field == null) { + throw new IllegalArgumentException( + "fields should not contain any null."); + } + + distinctNames.add(field.getName()); + } + if (distinctNames.size() != fields.length) { + throw new IllegalArgumentException("fields should have distinct names."); + } + + return new StructType(fields); + } + +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java new file mode 100644 index 0000000000000..bc54c078d7a4e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing java.math.BigDecimal values. + * + * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. + */ +public class DecimalType extends DataType { + protected DecimalType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java new file mode 100644 index 0000000000000..f0060d0bcf9f5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing double and Double values. + * + * {@code DoubleType} is represented by the singleton object {@link DataType#DoubleType}. + */ +public class DoubleType extends DataType { + protected DoubleType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java new file mode 100644 index 0000000000000..4a6a37f69176a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing float and Float values. + * + * {@code FloatType} is represented by the singleton object {@link DataType#FloatType}. 
+ */ +public class FloatType extends DataType { + protected FloatType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java new file mode 100644 index 0000000000000..bfd70490bbbbb --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing int and Integer values. + * + * {@code IntegerType} is represented by the singleton object {@link DataType#IntegerType}. + */ +public class IntegerType extends DataType { + protected IntegerType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java new file mode 100644 index 0000000000000..af13a46eb165c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing long and Long values. + * + * {@code LongType} is represented by the singleton object {@link DataType#LongType}. + */ +public class LongType extends DataType { + protected LongType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java new file mode 100644 index 0000000000000..063e6b34abc48 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing Maps. A MapType object comprises two fields, + * {@code DataType keyType}, {@code DataType valueType}, and {@code boolean valueContainsNull}. + * The field of {@code keyType} is used to specify the type of keys in the map. + * The field of {@code valueType} is used to specify the type of values in the map. + * The field of {@code valueContainsNull} is used to specify if map values have + * {@code null} values. + * For values of a MapType column, keys are not allowed to have {@code null} values. + * + * To create a {@link MapType}, + * {@link DataType#createMapType(DataType, DataType)} or + * {@link DataType#createMapType(DataType, DataType, boolean)} + * should be used. + */ +public class MapType extends DataType { + private DataType keyType; + private DataType valueType; + private boolean valueContainsNull; + + protected MapType(DataType keyType, DataType valueType, boolean valueContainsNull) { + this.keyType = keyType; + this.valueType = valueType; + this.valueContainsNull = valueContainsNull; + } + + public DataType getKeyType() { + return keyType; + } + + public DataType getValueType() { + return valueType; + } + + public boolean isValueContainsNull() { + return valueContainsNull; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + MapType mapType = (MapType) o; + + if (valueContainsNull != mapType.valueContainsNull) return false; + if (!keyType.equals(mapType.keyType)) return false; + if (!valueType.equals(mapType.valueType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = keyType.hashCode(); + result = 31 * result + valueType.hashCode(); + result = 31 * result + (valueContainsNull ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java new file mode 100644 index 0000000000000..7d7604b4e3d2d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
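As a small aside (not in the patch, reusing the DataType import from the sketch above), the MapType accessors behave as the class comment describes; `scores` is an invented name:

// Keys can never be null in a MapType column; whether *values* may be null is
// carried by the third argument and reported by isValueContainsNull.
val scores = DataType.createMapType(DataType.StringType, DataType.DoubleType, true)
assert(scores.getKeyType eq DataType.StringType)   // the singleton type object is returned as-is
assert(scores.isValueContainsNull)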
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing short and Short values. + * + * {@code ShortType} is represented by the singleton object {@link DataType#ShortType}. + */ +public class ShortType extends DataType { + protected ShortType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java new file mode 100644 index 0000000000000..f4ba0c07c9c6e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing String values. + * + * {@code StringType} is represented by the singleton object {@link DataType#StringType}. + */ +public class StringType extends DataType { + protected StringType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java new file mode 100644 index 0000000000000..b48e2a2c5f953 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * A StructField object represents a field in a StructType object. + * A StructField object comprises three fields, {@code String name}, {@code DataType dataType}, + * and {@code boolean nullable}. The field of {@code name} is the name of a StructField. + * The field of {@code dataType} specifies the data type of a StructField. + * The field of {@code nullable} specifies if values of a StructField can contain {@code null} + * values. + * + * To create a {@link StructField}, + * {@link DataType#createStructField(String, DataType, boolean)} + * should be used. 
+ */ +public class StructField { + private String name; + private DataType dataType; + private boolean nullable; + + protected StructField(String name, DataType dataType, boolean nullable) { + this.name = name; + this.dataType = dataType; + this.nullable = nullable; + } + + public String getName() { + return name; + } + + public DataType getDataType() { + return dataType; + } + + public boolean isNullable() { + return nullable; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructField that = (StructField) o; + + if (nullable != that.nullable) return false; + if (!dataType.equals(that.dataType)) return false; + if (!name.equals(that.name)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + dataType.hashCode(); + result = 31 * result + (nullable ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java new file mode 100644 index 0000000000000..a4b501efd9a10 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.Arrays; + +/** + * The data type representing Rows. + * A StructType object comprises an array of StructFields. + * + * To create an {@link StructType}, + * {@link DataType#createStructType(java.util.List)} or + * {@link DataType#createStructType(StructField[])} + * should be used. + */ +public class StructType extends DataType { + private StructField[] fields; + + protected StructType(StructField[] fields) { + this.fields = fields; + } + + public StructField[] getFields() { + return fields; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructType that = (StructType) o; + + if (!Arrays.equals(fields, that.fields)) return false; + + return true; + } + + @Override + public int hashCode() { + return Arrays.hashCode(fields); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java new file mode 100644 index 0000000000000..06d44c731cdfe --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
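One more sketch (again not part of the patch, imports as above): StructField and StructType implement structural equality via equals()/hashCode(), so schemas built independently from the same fields compare equal and can be inspected through the getters:

val field = DataType.createStructField("id", DataType.LongType, false)
val s1 = DataType.createStructType(Array(field))
val s2 = DataType.createStructType(
  Array(DataType.createStructField("id", DataType.LongType, false)))

assert(s1 == s2)                           // Arrays.equals over element-wise StructField.equals
assert(s1.getFields()(0).getName == "id")
assert(!field.isNullable)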
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing java.sql.Timestamp values. + * + * {@code TimestampType} is represented by the singleton object {@link DataType#TimestampType}. + */ +public class TimestampType extends DataType { + protected TimestampType() {} +} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java similarity index 63% rename from examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala rename to sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java index 42ae960bd64a1..ef959e35e1027 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark.examples.pythonconverters +package org.apache.spark.sql.api.java; -import org.apache.spark.api.python.Converter -import org.apache.hadoop.hbase.client.Result -import org.apache.hadoop.hbase.util.Bytes +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** /** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts a HBase Result - * to a String + * A Spark SQL UDF that has 1 arguments. */ -class HBaseConverter extends Converter[Any, String] { - override def convert(obj: Any): String = { - val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) - } +public interface UDF1 extends Serializable { + public R call(T1 t1) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java new file mode 100644 index 0000000000000..96ab3a96c3d5e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
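The UDFn interfaces are generic in the Java sources (UDF1<T1, R>, UDF2<T1, T2, R>, and so on); the type parameters simply do not survive the rendering above, so the signature below is inferred from the call() method. A rough sketch of implementing the one-argument variant from Scala, with an invented toUpper function; how such a function is registered is handled by the UDFRegistration trait mixed into SQLContext further below and is not shown here:

import org.apache.spark.sql.api.java.UDF1

val toUpper = new UDF1[String, String] {
  override def call(s: String): String = s.toUpperCase
}
toUpper.call("spark")   // "SPARK"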
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 10 arguments. + */ +public interface UDF10 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java new file mode 100644 index 0000000000000..58ae8edd6d817 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 11 arguments. + */ +public interface UDF11 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java new file mode 100644 index 0000000000000..d9da0f6eddd94 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 12 arguments. + */ +public interface UDF12 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java new file mode 100644 index 0000000000000..095fc1a8076b5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 13 arguments. + */ +public interface UDF13 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java new file mode 100644 index 0000000000000..eb27eaa180086 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 14 arguments. + */ +public interface UDF14 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java new file mode 100644 index 0000000000000..1fbcff56332b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 15 arguments. + */ +public interface UDF15 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java new file mode 100644 index 0000000000000..1133561787a69 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 16 arguments. + */ +public interface UDF16 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java new file mode 100644 index 0000000000000..dfae7922c9b63 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 17 arguments. + */ +public interface UDF17 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java new file mode 100644 index 0000000000000..e9d1c6d52d4ea --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 18 arguments. + */ +public interface UDF18 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java new file mode 100644 index 0000000000000..46b9d2d3c9457 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 19 arguments. + */ +public interface UDF19 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java new file mode 100644 index 0000000000000..cd3fde8da419e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 2 arguments. + */ +public interface UDF2 extends Serializable { + public R call(T1 t1, T2 t2) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java new file mode 100644 index 0000000000000..113d3d26be4a7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 20 arguments. + */ +public interface UDF20 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java new file mode 100644 index 0000000000000..74118f2cf8da7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 21 arguments. + */ +public interface UDF21 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java new file mode 100644 index 0000000000000..0e7cc40be45ec --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 22 arguments. + */ +public interface UDF22 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java new file mode 100644 index 0000000000000..6a880f16be47a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 3 arguments. + */ +public interface UDF3 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java new file mode 100644 index 0000000000000..fcad2febb18e6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 4 arguments. + */ +public interface UDF4 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java new file mode 100644 index 0000000000000..ce0cef43a2144 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 5 arguments. 
+ */ +public interface UDF5 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java new file mode 100644 index 0000000000000..f56b806684e61 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 6 arguments. + */ +public interface UDF6 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java new file mode 100644 index 0000000000000..25bd6d3241bd4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 7 arguments. 
+ */ +public interface UDF7 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java new file mode 100644 index 0000000000000..a3b7ac5f94ce7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 8 arguments. + */ +public interface UDF8 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java new file mode 100644 index 0000000000000..205e72a1522fc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 9 arguments. 
+ */ +public interface UDF9 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java similarity index 95% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java index 53603614518f5..67007a9f0d1a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java @@ -18,4 +18,4 @@ /** * Allows the execution of relational queries, including those expressed in SQL using Spark. */ -package org.apache.spark.sql; \ No newline at end of file +package org.apache.spark.sql.api.java; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2b787e14f3f15..0fd7aaaa36eb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -17,76 +17,128 @@ package org.apache.spark.sql +import scala.collection.immutable +import scala.collection.JavaConversions._ + import java.util.Properties -import scala.collection.JavaConverters._ + +private[spark] object SQLConf { + val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" + val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" + val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val CODEGEN_ENABLED = "spark.sql.codegen" + val DIALECT = "spark.sql.dialect" + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + } +} /** - * SQLConf holds mutable config parameters and hints. These can be set and - * queried either by passing SET commands into Spark SQL's DSL - * functions (sql(), hql(), etc.), or by programmatically using setters and - * getters of this class. + * A trait that enables the setting and getting of mutable config parameters/hints. + * + * In the presence of a SQLContext, these can be set and queried by passing SET commands + * into Spark SQL's query functions (i.e. sql()). Otherwise, users of this trait can + * modify the hints by programmatically calling the setters and getters of this trait. * - * SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads). + * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ trait SQLConf { + import SQLConf._ + + @transient protected[spark] val settings = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? + /** + * The SQL dialect that is used when parsing queries. This defaults to 'sql' which uses + * a simple SQL parser provided by Spark SQL. This is currently the only option for users of + * SQLContext. + * + * When using a HiveContext, this value defaults to 'hiveql', which uses the Hive 0.12.0 HiveQL + * parser. Users can change this to 'sql' if they want to run queries that aren't supported by + * HiveQL (e.g., SELECT 1). + * + * Note that the choice of dialect does not affect things like what tables are available or + * how query execution is performed. 
+ */ + private[spark] def dialect: String = getConf(DIALECT, "sql") + + /** When true tables cached using the in-memory columnar caching will be compressed. */ + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt + + /** + * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode + * that evaluates expressions found in queries. In general this custom code runs much faster + * than interpreted evaluation, but there are significant start-up costs due to compilation. + * As a result codegen is only benificial when queries run for a long time, or when the same + * expressions are used multiple times. + * + * Defaults to false as this feature is currently experimental. + */ + private[spark] def codegenEnabled: Boolean = + if (getConf(CODEGEN_ENABLED, "false") == "true") true else false /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to - * a broadcast value during the physical executions of join operations. Setting this to 0 + * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. - * Hive setting: hive.auto.convert.join.noconditionaltask.size. + * + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. */ - private[spark] def autoConvertJoinSize: Int = - get("spark.sql.auto.convert.join.size", "10000").toInt + private[spark] def autoBroadcastJoinThreshold: Int = + getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt - /** A comma-separated list of table names marked to be broadcasted during joins. */ - private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "") + /** + * The default size in bytes to assign to a logical operator's estimation statistics. By default, + * it is set to a larger value than `autoConvertJoinSize`, hence any logical operator without a + * properly implemented estimation of this statistic will not be incorrectly broadcasted in joins. + */ + private[spark] def defaultSizeInBytes: Long = + getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong /** ********************** SQLConf functionality methods ************ */ - @transient - private val settings = java.util.Collections.synchronizedMap( - new java.util.HashMap[String, String]()) - - def set(props: Properties): Unit = { - props.asScala.foreach { case (k, v) => this.settings.put(k, v) } + /** Set Spark SQL configuration properties. */ + def setConf(props: Properties): Unit = settings.synchronized { + props.foreach { case (k, v) => settings.put(k, v) } } - def set(key: String, value: String): Unit = { + /** Set the given Spark SQL configuration property. */ + def setConf(key: String, value: String): Unit = { require(key != null, "key cannot be null") - require(value != null, s"value cannot be null for ${key}") + require(value != null, s"value cannot be null for key: $key") settings.put(key, value) } - def get(key: String): String = { + /** Return the value of Spark SQL configuration property for the given key. 
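A minimal usage sketch of the renamed configuration accessors, assuming an existing SQLContext named sqlContext; as the scaladoc above notes, SET commands issued through sql() land in the same synchronized settings map as the programmatic setters.

{{{
// Programmatic route: setConf/getConf replace the old set/get.
sqlContext.setConf("spark.sql.codegen", "true")
sqlContext.getConf("spark.sql.codegen")              // "true"
sqlContext.getConf("spark.sql.dialect", "sql")       // falls back to the default when unset

// SQL route: SET commands update the same settings map.
sqlContext.sql("SET spark.sql.shuffle.partitions=10")
}}}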
*/ + def getConf(key: String): String = { Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key)) } - def get(key: String, defaultValue: String): String = { + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. + */ + def getConf(key: String, defaultValue: String): String = { Option(settings.get(key)).getOrElse(defaultValue) } - def getAll: Array[(String, String)] = settings.synchronized { settings.asScala.toArray } - - def getOption(key: String): Option[String] = Option(settings.get(key)) - - def contains(key: String): Boolean = settings.containsKey(key) - - def toDebugString: String = { - settings.synchronized { - settings.asScala.toArray.sorted.map{ case (k, v) => s"$k=$v" }.mkString("\n") - } - } + /** + * Return all the configuration properties that have been set (i.e. not the default). + * This creates a new copy of the config properties in the form of a Map. + */ + def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } private[spark] def clear() { settings.clear() } - } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4abd89955bd27..71d338d21d0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,11 +24,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -37,7 +36,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: @@ -49,18 +48,23 @@ import org.apache.spark.SparkContext */ @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) - extends Logging + extends org.apache.spark.Logging with SQLConf with ExpressionConversions + with UDFRegistration with Serializable { self => @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) + + @transient + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry + @transient protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true) + new Analyzer(catalog, functionRegistry, caseSensitive = true) @transient protected[sql] val optimizer = Optimizer @transient @@ -86,7 +90,45 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) + + /** + * :: DeveloperApi :: + * Creates a 
[[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * Example: + * {{{ + * import org.apache.spark.sql._ + * val sqlContext = new org.apache.spark.sql.SQLContext(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val peopleSchemaRDD = sqlContext. applySchema(people, schema) + * peopleSchemaRDD.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * peopleSchemaRDD.registerTempTable("people") + * sqlContext.sql("select name from people").collect.foreach(println) + * }}} + * + * @group userf + */ + @DeveloperApi + def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { + // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied + // schema differs from the existing schema on any field data type. + val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self) + new SchemaRDD(this, logicalPlan) + } /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. @@ -94,7 +136,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. @@ -104,6 +146,19 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + /** + * :: Experimental :: + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[SchemaRDD]]. + * + * @group userf + */ + @Experimental + def jsonFile(path: String, schema: StructType): SchemaRDD = { + val json = sparkContext.textFile(path) + jsonRDD(json, schema) + } + /** * :: Experimental :: */ @@ -122,12 +177,30 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[SchemaRDD]]. 
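For reference, a short sketch of the schema-aware jsonFile entry point added above; the path and field names are placeholders, and StructType, StructField, StringType and IntegerType are the aliases exported by org.apache.spark.sql, exactly as in the applySchema example.

{{{
import org.apache.spark.sql._

val schema =
  StructType(
    StructField("name", StringType, false) ::
    StructField("age", IntegerType, true) :: Nil)

// Skips schema inference and applies the user-supplied schema directly.
val people = sqlContext.jsonFile("examples/src/main/resources/people.json", schema)
people.registerTempTable("people_json")
}}}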
+ * + * @group userf + */ + @Experimental + def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val appliedSchema = + Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) + } + /** * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = - new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) + def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) + } /** * :: Experimental :: @@ -139,7 +212,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * import sqlContext._ * * case class Person(name: String, age: Int) - * createParquetFile[Person]("path/to/file.parquet").registerAsTable("people") + * createParquetFile[Person]("path/to/file.parquet").registerTempTable("people") * sql("INSERT INTO people SELECT 'michael', 29") * }}} * @@ -160,7 +233,8 @@ class SQLContext(@transient val sparkContext: SparkContext) conf: Configuration = new Configuration()): SchemaRDD = { new SchemaRDD( this, - ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf)) + ParquetRelation.createEmpty( + path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) } /** @@ -170,19 +244,22 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - val name = tableName - val newPlan = rdd.logicalPlan transform { - case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = name) - } - catalog.registerTable(None, tableName, newPlan) + catalog.registerTable(None, tableName, rdd.logicalPlan) } /** - * Executes a SQL query using Spark, returning the result as a SchemaRDD. + * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. 
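And the RDD[String] counterpart, reusing the schema and sqlContext from the previous sketch:

{{{
val jsonStrings = sc.parallelize(
  """{"name":"Alice","age":30}""" :: """{"name":"Bob","age":25}""" :: Nil)

val peopleFromRdd = sqlContext.jsonRDD(jsonStrings, schema)
peopleFromRdd.printSchema()
// root
// |-- name: string (nullable = false)
// |-- age: integer (nullable = true)
}}}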
* * @group userf */ - def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText)) + def sql(sqlText: String): SchemaRDD = { + if (dialect == "sql") { + new SchemaRDD(this, parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $dialect") + } + } /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = @@ -196,8 +273,6 @@ class SQLContext(@transient val sparkContext: SparkContext) currentTable.logicalPlan case _ => - val useCompression = - sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) } @@ -212,7 +287,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) - catalog.registerTable(None, tableName, SparkLogicalPlan(e)) + catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) case inMem: InMemoryRelation => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) @@ -234,12 +309,14 @@ class SQLContext(@transient val sparkContext: SparkContext) val sqlContext: SQLContext = self + def codegenEnabled = self.codegenEnabled + def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = CommandStrategy(self) :: TakeOrdered :: - PartialAggregation :: + HashAggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: @@ -297,27 +374,30 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** - * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and - * inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. */ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: - Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + Batch("Add exchange", Once, AddExchange(self)) :: Nil } /** + * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. */ + @DeveloperApi protected abstract class QueryExecution { def logical: LogicalPlan - lazy val analyzed = analyzer(logical) + lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... - lazy val sparkPlan = planner(optimizedPlan).next() + lazy val sparkPlan = { + SparkPlan.currentContext.set(self) + planner(optimizedPlan).next() + } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) @@ -337,41 +417,102 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} + |Code Generation: ${executedPlan.codegenEnabled} + |== RDD == + |${stringOrError(toRdd.toDebugString)} """.stripMargin.trim } /** - * Peek at the first row of the RDD and infer its schema. - * TODO: consolidate this with the type system developed in SPARK-2060. + * Parses the data type in our internal string representation. The data type string should + * have the same format as the one generated by `toString` in scala. 
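The cached-table path now reads its compression flag from SQLConf rather than SparkConf; a sketch using the table registered in the earlier JSON example (cacheTable is the existing SQLContext method):

{{{
sqlContext.setConf("spark.sql.inMemoryColumnarStorage.compressed", "true")
sqlContext.cacheTable("people_json")
// The first scan populates the in-memory columnar buffers; later scans read from them.
sqlContext.sql("SELECT count(*) FROM people_json").collect()
}}}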
+ * It is only used by PySpark. */ - private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { + private[sql] def parseDataType(dataTypeString: String): DataType = { + val parser = org.apache.spark.sql.catalyst.types.DataType + parser(dataTypeString) + } + + /** + * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. + */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Array[Any]], + schemaString: String): SchemaRDD = { + val schema = parseDataType(schemaString).asInstanceOf[StructType] + applySchemaToPythonRDD(rdd, schema) + } + + /** + * Apply a schema defined by the schema to an RDD. It is only used by PySpark. + */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Array[Any]], + schema: StructType): SchemaRDD = { import scala.collection.JavaConversions._ - def typeFor(obj: Any): DataType = obj match { - case c: java.lang.String => StringType - case c: java.lang.Integer => IntegerType - case c: java.lang.Long => LongType - case c: java.lang.Double => DoubleType - case c: java.lang.Boolean => BooleanType - case c: java.util.List[_] => ArrayType(typeFor(c.head)) - case c: java.util.Set[_] => ArrayType(typeFor(c.head)) - case c: java.util.Map[_, _] => - val (key, value) = c.head - MapType(typeFor(key), typeFor(value)) - case c if c.getClass.isArray => - val elem = c.asInstanceOf[Array[_]].head - ArrayType(typeFor(elem)) - case c => throw new Exception(s"Object of type $c cannot be used") + import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} + + def needsConversion(dataType: DataType): Boolean = dataType match { + case ByteType => true + case ShortType => true + case FloatType => true + case TimestampType => true + case ArrayType(_, _) => true + case MapType(_, _, _) => true + case StructType(_) => true + case other => false } - val schema = rdd.first().map { case (fieldName, obj) => - AttributeReference(fieldName, typeFor(obj), true)() - }.toSeq - val rowRdd = rdd.mapPartitions { iter => - iter.map { map => - new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row - } + // Converts value to the type specified by the data type. + // Because Python does not have data types for TimestampType, FloatType, ShortType, and + // ByteType, we need to explicitly convert values in columns of these data types to the desired + // JVM data types. 
+ def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match { + // TODO: We should check nullable + case (null, _) => null + + case (c: java.util.List[_], ArrayType(elementType, _)) => + val converted = c.map { e => convert(e, elementType)} + JListWrapper(converted) + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any] + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { + case (key, value) => (convert(key, keyType), convert(value, valueType)) + }.toMap + + case (c, StructType(fields)) if c.getClass.isArray => + new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { + case (e, f) => convert(e, f.dataType) + }): Row + + case (c: java.util.Calendar, TimestampType) => + new java.sql.Timestamp(c.getTime().getTime()) + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + case (c: Long, IntegerType) => c.toInt + case (c: Double, FloatType) => c.toFloat + case (c, StringType) if !c.isInstanceOf[String] => c.toString + + case (c, _) => c + } + + val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { + rdd.map(m => m.zip(schema.fields).map { + case (value, field) => convert(value, field.dataType) + }) + } else { + rdd } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) - } + val rowRdd = convertedRdd.mapPartitions { iter => + iter.map { m => new GenericRow(m): Row} + } + + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 31d27bb4f0571..57df79321b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList, Set => JSet} +import java.util.{Map => JMap, List => JList} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD @@ -68,7 +67,7 @@ import org.apache.spark.api.java.JavaRDD * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) * // Any RDD containing case classes can be registered as a table. The schema of the table is * // automatically inferred using scala reflection. - * rdd.registerAsTable("records") + * rdd.registerTempTable("records") * * val results: SchemaRDD = sql("SELECT * FROM records") * }}} @@ -120,6 +119,11 @@ class SchemaRDD( override protected def getDependencies: Seq[Dependency[_]] = List(new OneToOneDependency(queryExecution.toRdd)) + /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + def schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL @@ -376,48 +380,36 @@ class SchemaRDD( * Converts a JavaRDD to a PythonRDD. It is used by pyspark. 
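The new schema accessor exposes the analyzed output as a StructType, so callers no longer need to parse schemaString; a sketch using the table from the earlier JSON example:

{{{
val people = sqlContext.table("people_json")
people.schema.fields.foreach { f =>
  println(s"${f.name}: ${f.dataType} (nullable = ${f.nullable})")
}
}}}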
*/ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { - val fields = structType.fields.map(field => (field.name, field.dataType)) - val map: JMap[String, Any] = new java.util.HashMap - row.zip(fields).foreach { - case (obj, (attrName, dataType)) => - dataType match { - case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) - case array @ ArrayType(struct: StructType) => - val arrayValues = obj match { - case seq: Seq[Any] => - seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava - case list: JList[_] => - list.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case set: JSet[_] => - set.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case arr if arr != null && arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map { - element => rowToMap(element.asInstanceOf[Row], struct) - } - case other => other - } - map.put(attrName, arrayValues) - case array: ArrayType => { - val arrayValues = obj match { - case seq: Seq[Any] => seq.asJava - case other => other - } - map.put(attrName, arrayValues) - } - case other => map.put(attrName, obj) - } + import scala.collection.Map + + def toJava(obj: Any, dataType: DataType): Any = dataType match { + case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct) + case array: ArrayType => obj match { + case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava + case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava + case arr if arr != null && arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + case other => other } - - map + case mt: MapType => obj.asInstanceOf[Map[_, _]].map { + case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type + }.asJava + // Pyrolite can handle Timestamp + case other => obj + } + def rowToArray(row: Row, structType: StructType): Array[Any] = { + val fields = structType.fields.map(field => field.dataType) + row.zip(fields).map { + case (obj, dataType) => toJava(obj, dataType) + }.toArray } val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => - rowToMap(row, rowSchema) - }.grouped(10).map(batched => pickle.dumps(batched.toArray)) + rowToArray(row, rowSchema) + }.grouped(100).map(batched => pickle.dumps(batched.toArray)) } } @@ -430,7 +422,8 @@ class SchemaRDD( * @group schema */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))) + new SchemaRDD(sqlContext, + SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))(sqlContext)) } // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index fe81721943202..2f3033a5f94f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -56,7 +56,7 @@ private[sql] trait SchemaRDDLike { // happen right away to let these side effects take place eagerly. 
case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => queryExecution.toRdd - SparkLogicalPlan(queryExecution.executedPlan) + SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) case _ => baseLogicalPlan } @@ -83,10 +83,13 @@ private[sql] trait SchemaRDDLike { * * @group schema */ - def registerAsTable(tableName: String): Unit = { + def registerTempTable(tableName: String): Unit = { sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) } + @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") + def registerAsTable(tableName: String): Unit = registerTempTable(tableName) + /** * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. @@ -123,9 +126,15 @@ private[sql] trait SchemaRDDLike { def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd - /** Returns the output schema in the tree format. */ - def schemaString: String = queryExecution.analyzed.schemaString + /** Returns the schema as a string in the tree format. + * + * @group schema + */ + def schemaString: String = baseSchemaRDD.schema.treeString - /** Prints out the schema in the tree format. */ + /** Prints out the schema. + * + * @group schema + */ def printSchema(): Unit = println(schemaString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala new file mode 100644 index 0000000000000..0b48e9e659faa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.execution.PythonUDF + +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +/** + * Functions for registering scala lambda functions as UDFs in a SQLContext. 
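registerAsTable is kept only as a deprecated forwarder; new code should call registerTempTable. A sketch with an inferred case-class schema (names are illustrative):

{{{
case class Record(key: Int, value: String)

import sqlContext.createSchemaRDD   // implicit RDD -> SchemaRDD conversion
val records = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))

records.registerTempTable("records")   // preferred name
// records.registerAsTable("records")  // still compiles, but emits a deprecation warning
sqlContext.sql("SELECT value FROM records WHERE key = 1").collect()
}}}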
+ */ +protected[sql] trait UDFRegistration { + self: SQLContext => + + private[spark] def registerPython( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + stringDataType: String): Unit = { + log.debug( + s""" + | Registering new PythonUDF: + | name: $name + | command: ${command.toSeq} + | envVars: $envVars + | pythonIncludes: $pythonIncludes + | pythonExec: $pythonExec + | dataType: $stringDataType + """.stripMargin) + + + val dataType = parseDataType(stringDataType) + + def builder(e: Seq[Expression]) = + PythonUDF( + name, + command, + envVars, + pythonIncludes, + pythonExec, + accumulator, + dataType, + e) + + functionRegistry.registerFunction(name, builder) + } + + /** registerFunction 1-22 were generated by this script + + (1 to 22).map { x => + val types = (1 to x).map(x => "_").reduce(_ + ", " + _) + s""" + def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { + def builder(e: Seq[Expression]) = + ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + """ + } + */ + + // scalastyle:off + def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: 
String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: 
Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 790d9ef22cf16..150ff8a42063d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -21,28 +21,36 @@ import java.beans.Introspector import org.apache.hadoop.conf.Configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils /** * The entry point for executing Spark SQL queries from a Java program. */ -class JavaSQLContext(val sqlContext: SQLContext) { +class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) /** - * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD + * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. + * + * @group userf */ - def sql(sqlQuery: String): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) + def sql(sqlText: String): JavaSchemaRDD = { + if (sqlContext.dialect == "sql") { + new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $sqlContext.dialect") + } + } /** * :: Experimental :: @@ -52,7 +60,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { * {{{ * JavaSQLContext sqlCtx = new JavaSQLContext(...) * - * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerAsTable("people") + * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerTempTable("people") * sqlCtx.sql("INSERT INTO people SELECT 'michael', 29") * }}} * @@ -72,7 +80,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { conf: Configuration = new Configuration()): JavaSchemaRDD = { new JavaSchemaRDD( sqlContext, - ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf)) + ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext)) } /** @@ -92,7 +100,22 @@ class JavaSQLContext(val sqlContext: SQLContext) { new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow } } - new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) + new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext)) + } + + /** + * :: DeveloperApi :: + * Creates a JavaSchemaRDD from an RDD containing Rows by applying a schema to this RDD. + * It is important to make sure that the structure of every Row of the provided RDD matches the + * provided schema. 
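On the Scala side, registerFunction takes an ordinary function value and infers the return DataType through ScalaReflection, so no explicit type needs to be supplied; a sketch using the records table registered above:

{{{
sqlContext.registerFunction("strLen", (s: String) => s.length)
sqlContext.sql("SELECT strLen(value) FROM records").collect()
}}}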
Otherwise, there will be runtime exception. + */ + @DeveloperApi + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): JavaSchemaRDD = { + val scalaRowRDD = rowRDD.rdd.map(r => r.row) + val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] + val logicalPlan = + SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) } /** @@ -101,26 +124,52 @@ class JavaSQLContext(val sqlContext: SQLContext) { def parquetFile(path: String): JavaSchemaRDD = new JavaSchemaRDD( sqlContext, - ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext)) /** - * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. - * - * @group userf */ def jsonFile(path: String): JavaSchemaRDD = jsonRDD(sqlContext.sparkContext.textFile(path)) + /** + * :: Experimental :: + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonFile(path: String, schema: StructType): JavaSchemaRDD = + jsonRDD(sqlContext.sparkContext.textFile(path), schema) + /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[JavaSchemaRDD]]. + * JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. - * - * @group userf */ - def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0)) + def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { + val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = + SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) + } + + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { + val appliedScalaSchema = + Option(asScalaDataType(schema)).getOrElse( + JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = + SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) + } /** * Registers the given RDD as a temporary table in the catalog. 
Temporary tables exist only @@ -138,22 +187,37 @@ class JavaSQLContext(val sqlContext: SQLContext) { val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + case c: Class[_] if c == classOf[java.lang.String] => + (org.apache.spark.sql.StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => + (org.apache.spark.sql.ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => + (org.apache.spark.sql.IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => + (org.apache.spark.sql.LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => + (org.apache.spark.sql.DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => + (org.apache.spark.sql.ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => + (org.apache.spark.sql.FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => + (org.apache.spark.sql.BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => + (org.apache.spark.sql.ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => + (org.apache.spark.sql.IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => + (org.apache.spark.sql.LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => + (org.apache.spark.sql.DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => + (org.apache.spark.sql.ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => + (org.apache.spark.sql.FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => + (org.apache.spark.sql.BooleanType, true) } AttributeReference(property.getName, dataType, nullable)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 8fbf13b8b0150..4d799b4038fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -22,8 +22,10 @@ import java.util.{List => JList} import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} 
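The property-type mapping above is what JavaSQLContext.applySchema(rdd, beanClass) relies on when inferring a schema from a bean. A sketch of a compatible bean written in Scala with @BeanProperty; javaSqlCtx and peopleJavaRdd are assumed to exist.

{{{
import scala.beans.BeanProperty

// Primitive-typed getters map to non-nullable columns; boxed and String getters to nullable ones.
class PersonBean extends Serializable {
  @BeanProperty var name: String = _
  @BeanProperty var age: Int = _
}

// javaSqlCtx.applySchema(peopleJavaRdd, classOf[PersonBean]).registerTempTable("people_beans")
}}}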
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import DataTypeConversions._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -53,6 +55,10 @@ class JavaSchemaRDD( override def toString: String = baseSchemaRDD.toString + /** Returns the schema of this JavaSchemaRDD (represented by a StructType). */ + def schema: StructType = + asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType] + // ======================================================================= // Base RDD functions that do NOT change schema // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 9b0dd2176149b..6c67934bda5b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.api.java +import scala.annotation.varargs +import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} +import scala.collection.JavaConversions +import scala.math.BigDecimal + import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -29,7 +34,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { /** Returns the value of column `i`. */ def get(i: Int): Any = - row(i) + Row.toJavaValue(row(i)) /** Returns true if value at column `i` is NULL. */ def isNullAt(i: Int) = get(i) == null @@ -89,5 +94,57 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { */ def getString(i: Int): String = row.getString(i) + + def canEqual(other: Any): Boolean = other.isInstanceOf[Row] + + override def equals(other: Any): Boolean = other match { + case that: Row => + (that canEqual this) && + row == that.row + case _ => false + } + + override def hashCode(): Int = row.hashCode() } +object Row { + + private def toJavaValue(value: Any): Any = value match { + // For values of this ScalaRow, we will do the conversion when + // they are actually accessed. + case row: ScalaRow => new Row(row) + case map: scala.collection.Map[_, _] => + JavaConversions.mapAsJavaMap( + map.map { + case (key, value) => (toJavaValue(key), toJavaValue(value)) + } + ) + case seq: scala.collection.Seq[_] => + JavaConversions.seqAsJavaList(seq.map(toJavaValue)) + case decimal: BigDecimal => decimal.underlying() + case other => other + } + + // TODO: Consolidate the toScalaValue at here with the scalafy in JsonRDD? + private def toScalaValue(value: Any): Any = value match { + // Values of this row have been converted to Scala values. + case row: Row => row.row + case map: java.util.Map[_, _] => + JMapWrapper(map).map { + case (key, value) => (toScalaValue(key), toScalaValue(value)) + } + case list: java.util.List[_] => + JListWrapper(list).map(toScalaValue) + case decimal: java.math.BigDecimal => BigDecimal(decimal) + case other => other + } + + /** + * Creates a Row with the given values. + */ + @varargs def create(values: Any*): Row = { + // Right now, we cannot use @varargs to annotate the constructor of + // org.apache.spark.sql.api.java.Row. See https://issues.scala-lang.org/browse/SI-8383. 
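Row.create gives Java callers a varargs factory; values are converted to Scala representations on construction and back to Java-friendly values on access. A sketch:

{{{
import org.apache.spark.sql.api.java.Row

val r = Row.create("Alice", 30)
r.getString(0)   // "Alice"
r.get(1)         // 30, returned through toJavaValue
r.isNullAt(1)    // false
}}}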
+ new Row(ScalaRow(values.map(toScalaValue):_*)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala new file mode 100644 index 0000000000000..158f26e3d445f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala @@ -0,0 +1,252 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.types.util.DataTypeConversions._ + +/** + * A collection of functions that allow Java users to register UDFs. In order to handle functions + * of varying airities with minimal boilerplate for our users, we generate classes and functions + * for each airity up to 22. The code for this generation can be found in comments in this trait. + */ +private[java] trait UDFRegistration { + self: JavaSQLContext => + + /* The following functions and required interfaces are generated with these code fragments: + + (1 to 22).foreach { i => + val extTypeArgs = (1 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + println(s""" + |def registerFunction( + | name: String, f: UDF$i[$extTypeArgs, _], @transient dataType: DataType) = { + | val scalaType = asScalaDataType(dataType) + | sqlContext.functionRegistry.registerFunction( + | name, + | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), scalaType, e)) + |} + """.stripMargin) + } + + import java.io.File + import org.apache.spark.sql.catalyst.util.stringToFile + val directory = new File("sql/core/src/main/java/org/apache/spark/sql/api/java/") + (1 to 22).foreach { i => + val typeArgs = (1 to i).map(i => s"T$i").mkString(", ") + val args = (1 to i).map(i => s"T$i t$i").mkString(", ") + + val contents = + s"""/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.sql.api.java; + | + |import java.io.Serializable; + | + |// ************************************************** + |// THIS FILE IS AUTOGENERATED BY CODE IN + |// org.apache.spark.sql.api.java.FunctionRegistration + |// ************************************************** + | + |/** + | * A Spark SQL UDF that has $i arguments. + | */ + |public interface UDF$i<$typeArgs, R> extends Serializable { + | public R call($args) throws Exception; + |} + |""".stripMargin + + stringToFile(new File(directory, s"UDF$i.java"), contents) + } + + */ + + // scalastyle:off + def registerFunction(name: String, f: UDF1[_, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF2[_, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF3[_, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF4[_, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF5[_, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF6[_, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF7[_, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + 
sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } 
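Putting the pieces together: registering one of the generated UDF interfaces through the Java API requires passing the return DataType explicitly, because erasure prevents the reflection-based inference used on the Scala side. A sketch in Scala syntax; javaSqlCtx is assumed, and DataType.IntegerType refers to the Java-API type singleton assumed to accompany these classes.

{{{
import org.apache.spark.sql.api.java.{DataType, UDF1}

javaSqlCtx.registerFunction("strLenJava",
  new UDF1[String, java.lang.Integer] {
    override def call(s: String): java.lang.Integer = Int.box(s.length)
  },
  DataType.IntegerType)

// Assuming a table was registered through the same context earlier.
javaSqlCtx.sql("SELECT strLenJava(name) FROM people_beans")
}}}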
+ + def registerFunction(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + // scalastyle:on +} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 74f5630fbddf1..7e7bb2859bbcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -118,7 +118,7 @@ private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC) private[sql] object ColumnBuilder { - val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104 + val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = { if (orig.remaining >= size) { @@ -154,6 +154,7 @@ private[sql] object ColumnBuilder { case STRING.typeId => new StringColumnBuilder case BINARY.typeId => new BinaryColumnBuilder case GENERIC.typeId => new GenericColumnBuilder + case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index d008806eedbe1..f631ee76fcd78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -36,9 +36,9 @@ import org.apache.spark.sql.Row * }}} */ private[sql] trait NullableColumnBuilder extends ColumnBuilder { - private var nulls: ByteBuffer = _ + protected var nulls: ByteBuffer = _ + protected var nullCount: Int = _ private var pos: Int = _ - private var nullCount: Int = _ abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { nulls = ByteBuffer.allocate(1024) @@ -78,4 +78,9 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { buffer.rewind() buffer } + + protected def buildNonNulls(): ByteBuffer = { + nulls.limit(nulls.position()).rewind() + super.build() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 4c6675c3c87bf..a5826bb033e41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.Logging +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} @@ -45,8 +46,6 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] this: NativeColumnBuilder[T] with WithCompressionSchemes => - import CompressionScheme._ - var compressionEncoders: Seq[Encoder[T]] = _ abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { @@ -80,28 +79,32 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] } } - abstract override def build() = { - val rawBuffer = super.build() + override def build() = { + val nonNullBuffer = buildNonNulls() + val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { 
val candidate = compressionEncoders.minBy(_.compressionRatio) if (isWorthCompressing(candidate)) candidate else PassThrough.encoder } - val headerSize = columnHeaderSize(rawBuffer) + // Header = column type ID + null count + null positions + val headerSize = 4 + 4 + nulls.limit() val compressedSize = if (encoder.compressedSize == 0) { - rawBuffer.limit - headerSize + nonNullBuffer.remaining() } else { encoder.compressedSize } - // Reserves 4 bytes for compression scheme ID val compressedBuffer = ByteBuffer + // Reserves 4 bytes for compression scheme ID .allocate(headerSize + 4 + compressedSize) .order(ByteOrder.nativeOrder) + // Write the header + .putInt(typeId) + .putInt(nullCount) + .put(nulls) - copyColumnHeader(rawBuffer, compressedBuffer) - - logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") - encoder.compress(rawBuffer, compressedBuffer, columnType) + logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(nonNullBuffer, compressedBuffer, columnType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index ba1810dd2ae66..7797f75177893 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -67,22 +67,6 @@ private[sql] object CompressionScheme { s"Unrecognized compression scheme type ID: $typeId")) } - def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) { - // Writes column type ID - to.putInt(from.getInt()) - - // Writes null count - val nullCount = from.getInt() - to.putInt(nullCount) - - // Writes null positions - var i = 0 - while (i < nullCount) { - to.putInt(from.getInt()) - i += 1 - } - } - def columnHeaderSize(columnBuffer: ByteBuffer): Int = { val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder) val nullCount = header.getInt(4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c1ced8bfa404a..463a1d32d7fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -42,8 +42,8 @@ case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sqlContext: SQLContext) - extends UnaryNode with NoBind { + child: SparkPlan) + extends UnaryNode { override def requiredChildDistribution = if (partial) { @@ -56,8 +56,6 @@ case class Aggregate( } } - override def otherCopyArgs = sqlContext :: Nil - // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. 
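For reference, the rewritten CompressibleColumnBuilder.build() above lays the buffer out as [column type ID][null count][null positions][compression scheme ID][compressed payload], which is why headerSize is now computed as 4 + 4 + nulls.limit(). Below is a minimal, self-contained sketch of that layout; all values are made up and only the byte layout mirrors the change.

    import java.nio.{ByteBuffer, ByteOrder}

    object ColumnHeaderSketch extends App {
      val typeId        = 42                           // made-up column type ID
      val nullPositions = Array(3, 7)                  // made-up null row ordinals
      val payload       = Array[Byte](1, 2, 3, 4)      // stands in for compressed values

      val buf = ByteBuffer
        .allocate(4 + 4 + 4 * nullPositions.length + 4 + payload.length)
        .order(ByteOrder.nativeOrder())
        .putInt(typeId)                                // column type ID
        .putInt(nullPositions.length)                  // null count
      nullPositions.foreach(p => buf.putInt(p))        // null positions
      buf.putInt(0)                                    // compression scheme ID (the encoder writes this)
      buf.put(payload)                                 // compressed non-null values
      buf.flip()

      // headerSize above is everything before the scheme ID: 4 + 4 + 4 * nullCount.
      val readTypeId    = buf.getInt()
      val readNullCount = buf.getInt()
      println(s"typeId=$readTypeId nullCount=$readNullCount headerSize=${4 + 4 + 4 * readNullCount}")
    }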
private[this] val childOutput = child.output @@ -138,7 +136,7 @@ case class Aggregate( i += 1 } } - val resultProjection = new Projection(resultExpressions, computedSchema) + val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) val aggregateResults = new GenericMutableRow(computedAggregates.length) var i = 0 @@ -152,7 +150,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) var currentRow: Row = null while (iter.hasNext) { @@ -175,7 +173,8 @@ case class Aggregate( private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = - new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + new InterpretedMutableProjection( + resultExpressions, computedSchema ++ namedGroups.map(_._2)) private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 00010ef6e798a..77dc2ad733215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning = newPartitioning @@ -42,12 +42,14 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. 
val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions, child.output) + @transient val hashExpressions = + newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part) + val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -60,7 +62,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(row => mutablePair.update(row, null)) } val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part) + val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._1) @@ -71,7 +73,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(null, r)) } val partitioner = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner) + val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -99,7 +101,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - logger.debug( + logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 47b3d00262dbb..c386fd121c5de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -47,23 +47,26 @@ case class Generate( } } - override def output = + // This must be a val since the generator output expr ids are not preserved by serialization. + override val output = if (join) child.output ++ generatorOutput else generatorOutput + val boundGenerator = BindReferences.bindReference(generator, child.output) + override def execute() = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) // Used to produce rows with no matches when outer = true. 
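The HashPartitioning branch of Exchange above projects each row to its grouping key, pairs key and row through a reused MutablePair, and lets a HashPartitioner route the pair. A plain-Scala sketch of just the routing rule follows; the sample data and key function are made up, and nonNegativeMod is the same rule Spark's HashPartitioner applies to the key's hashCode.

    object HashRepartitionSketch extends App {
      def nonNegativeMod(x: Int, mod: Int): Int = { val r = x % mod; if (r < 0) r + mod else r }

      val numPartitions = 4
      val rows  = Seq(("a", 1), ("b", 2), ("a", 3), ("c", 4))
      val keyOf = (r: (String, Int)) => r._1           // stands in for the key projection

      // Rows with equal keys always land in the same partition, which is what the
      // ClusteredDistribution requirement of the downstream operator relies on.
      val byPartition = rows.groupBy(r => nonNegativeMod(keyOf(r).hashCode, numPartitions))
      byPartition.toSeq.sortBy(_._1).foreach { case (pid, rs) => println(s"partition $pid -> $rs") }
    }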
val outerProjection = - new Projection(child.output ++ nullValues, child.output) + newProjection(child.output ++ nullValues, child.output) val joinProjection = - new Projection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generator.output, child.output ++ generator.output) val joinedRow = new JoinedRow iter.flatMap {row => - val outputRows = generator.eval(row) + val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { outerProjection(row) :: Nil } else { @@ -72,7 +75,7 @@ case class Generate( } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row))) + child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala new file mode 100644 index 0000000000000..4a26934c49c93 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ + +case class AggregateEvaluation( + schema: Seq[Attribute], + initialValues: Seq[Expression], + update: Seq[Expression], + result: Expression) + +/** + * :: DeveloperApi :: + * Alternate version of aggregation that leverages projection and thus code generation. + * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto + * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. 
+ */ +@DeveloperApi +case class GeneratedAggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def output = aggregateExpressions.map(_.toAttribute) + + override def execute() = { + val aggregatesToCompute = aggregateExpressions.flatMap { a => + a.collect { case agg: AggregateExpression => agg} + } + + val computeFunctions = aggregatesToCompute.map { + case c @ Count(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val initialValue = Literal(0L) + val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val result = currentCount + + AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case Sum(expr) => + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialValue = Cast(Literal(0L), expr.dataType) + + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + val result = currentSum + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case a @ Average(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialCount = Literal(0L) + val initialSum = Cast(Literal(0L), expr.dataType) + val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + + val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + + AggregateEvaluation( + currentCount :: currentSum :: Nil, + initialCount :: initialSum :: Nil, + updateCount :: updateSum :: Nil, + result + ) + } + + val computationSchema = computeFunctions.flatMap(_.schema) + + val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => agg.id -> func.result + }.toMap + + val namedGroups = groupingExpressions.zipWithIndex.map { + case (ne: NamedExpression, _) => (ne, ne) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + } + + val groupMap: Map[Expression, Attribute] = + namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + + // The set of expressions that produce the final output given the aggregation buffer and the + // grouping expressions. + val resultExpressions = aggregateExpressions.map(_.transform { + case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if groupMap.contains(e) => groupMap(e) + }) + + child.execute().mapPartitions { iter => + // Builds a new custom class for holding the results of aggregation for a group. + val initialValues = computeFunctions.flatMap(_.initialValues) + val newAggregationBuffer = newProjection(initialValues, child.output) + log.info(s"Initial values: ${initialValues.mkString(",")}") + + // A projection that computes the group given an input tuple. 
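Each case in computeFunctions above reduces an aggregate to an AggregateEvaluation triple: initial buffer values, an update expression applied per input row, and a result expression. The plain-Scala model below mirrors the Average case (AvgBuffer is an illustrative name, not Spark code); the None branch corresponds to Coalesce falling back to the current sum when the input expression evaluates to null.

    object AggregateEvaluationSketch extends App {
      final case class AvgBuffer(count: Long, sum: Double)

      val initial = AvgBuffer(0L, 0.0)                       // initialCount, initialSum
      def update(b: AvgBuffer, v: Option[Double]): AvgBuffer = v match {
        case Some(x) => AvgBuffer(b.count + 1, b.sum + x)    // updateCount, updateSum
        case None    => b                                    // null input leaves the buffer unchanged
      }
      def result(b: AvgBuffer): Double = b.sum / b.count     // Divide(currentSum, currentCount)

      val input = Seq(Some(1.0), None, Some(4.0), Some(7.0))
      println(result(input.foldLeft(initial)(update)))       // prints 4.0
    }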
+ val groupProjection = newProjection(groupingExpressions, child.output) + log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") + + // A projection that is used to update the aggregate values for a group given a new tuple. + // This projection should be targeted at the current values for the group and then applied + // to a joined row of the current values with the new input row. + val updateExpressions = computeFunctions.flatMap(_.update) + val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output + val updateProjection = newMutableProjection(updateExpressions, updateSchema)() + log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") + + // A projection that produces the final result, given a computation. + val resultProjectionBuilder = + newMutableProjection( + resultExpressions, + (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + log.info(s"Result Projection: ${resultExpressions.mkString(",")}") + + val joinedRow = new JoinedRow + + if (groupingExpressions.isEmpty) { + // TODO: Codegening anything other than the updateProjection is probably over kill. + val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + var currentRow: Row = null + updateProjection.target(buffer) + + while (iter.hasNext) { + currentRow = iter.next() + updateProjection(joinedRow(buffer, currentRow)) + } + + val resultProjection = resultProjectionBuilder() + Iterator(resultProjection(buffer)) + } else { + val buffers = new java.util.HashMap[Row, MutableRow]() + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupProjection(currentRow) + var currentBuffer = buffers.get(currentGroup) + if (currentBuffer == null) { + currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + buffers.put(currentGroup, currentBuffer) + } + // Target the projection at the current aggregation buffer and then project the updated + // values. 
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val resultIterator = buffers.entrySet.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext = resultIterator.hasNext + + def next() = { + val currentGroup = resultIterator.next() + resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 27dc091b85812..21cbbc9772a00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -18,22 +18,55 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Logging, Row} + + +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.BaseRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ + +object SparkPlan { + protected[sql] val currentContext = new ThreadLocal[SQLContext]() +} + /** * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { self: Product => + /** + * A handle to the SQL Context that was used to create this plan. Since many operators need + * access to the sqlContext for RDD operations or configuration this field is automatically + * populated by the query planning infrastructure. + */ + @transient + protected val sqlContext = SparkPlan.currentContext.get() + + protected def sparkContext = sqlContext.sparkContext + + // sqlContext will be null when we are being deserialized on the slaves. In this instance + // the value of codegenEnabled will be set by the desserializer after the constructor has run. + val codegenEnabled: Boolean = if (sqlContext != null) { + sqlContext.codegenEnabled + } else { + false + } + + /** Overridden make copy also propogates sqlContext to copied plan. */ + override def makeCopy(newArgs: Array[AnyRef]): this.type = { + SparkPlan.currentContext.set(sqlContext) + super.makeCopy(newArgs) + } + // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! 
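The codegenEnabled flag introduced above switches every operator between interpreted and generated evaluation through the newProjection and newMutableProjection helpers that follow. The toy model below shows only that switch; ProjectionFactorySketch and the index-based "expressions" are illustrative, and the real GenerateProjection compiles a specialized class rather than reusing the interpreted path.

    object ProjectionFactorySketch extends App {
      type Row        = Seq[Any]
      type Projection = Row => Row

      // Interpreted: evaluate the (here: column-index) expressions against every row.
      def interpretedProjection(indices: Seq[Int]): Projection = row => indices.map(row)

      // Stand-in for code generation; shown only so the factory has two branches.
      def generatedProjection(indices: Seq[Int]): Projection = interpretedProjection(indices)

      def newProjection(indices: Seq[Int], codegenEnabled: Boolean): Projection =
        if (codegenEnabled) generatedProjection(indices) else interpretedProjection(indices)

      val project = newProjection(Seq(1, 0), codegenEnabled = true)
      println(project(Seq("a", 42)))   // List(42, a)
    }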
@@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { */ def executeCollect(): Array[Row] = execute().map(_.copy()).collect() - protected def buildRow(values: Seq[Any]): Row = - new GenericRow(values.toArray) + protected def newProjection( + expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + GenerateProjection(expressions, inputSchema) + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if(codegenEnabled) { + GenerateMutableProjection(expressions, inputSchema) + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + + + protected def newPredicate( + expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + if (codegenEnabled) { + GeneratePredicate(expression, inputSchema) + } else { + InterpretedPredicate(expression, inputSchema) + } + } + + protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + if (codegenEnabled) { + GenerateOrdering(order, inputSchema) + } else { + new RowOrdering(order, inputSchema) + } + } } /** @@ -66,8 +137,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { * linking. */ @DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "SparkLogicalPlan") - extends BaseRelation with MultiInstanceRelation { +case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { def output = alreadyPlanned.output override def references = Set.empty @@ -78,9 +149,15 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "Spar alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) case _ => sys.error("Multiple instance of the same relation detected.") - }, tableName) - .asInstanceOf[this.type] + })(sqlContext).asInstanceOf[this.type] } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. 
+ sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) + } private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c078e71fe0290..f0c958fdb537f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ @@ -39,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition)(sqlContext) :: Nil + planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -47,9 +47,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be * evaluated by matching hash keys. + * + * This strategy applies a simple optimization based on the estimates of the physical sizes of + * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * estimated physical size smaller than the user-settable threshold + * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the + * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be + * ''broadcasted'' to all of the executors involved in the join, as a + * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they + * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. 
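The strategy documented above chooses the join implementation from the two sides' size estimates. The self-contained sketch below isolates just that decision; the byte sizes and threshold are made up, and BuildLeft/BuildRight mirror the case objects defined in joins.scala.

    object JoinSideSelectionSketch extends App {
      sealed trait BuildSide
      case object BuildLeft  extends BuildSide
      case object BuildRight extends BuildSide

      sealed trait JoinChoice
      case class Broadcast(side: BuildSide) extends JoinChoice
      case class Shuffled(side: BuildSide)  extends JoinChoice

      // Broadcast whichever side fits under the threshold; otherwise fall back to a
      // shuffled hash join and build the hash table on the smaller side.
      def choose(leftBytes: BigInt, rightBytes: BigInt, autoBroadcastThreshold: Long): JoinChoice =
        if (autoBroadcastThreshold > 0 && rightBytes <= autoBroadcastThreshold) Broadcast(BuildRight)
        else if (autoBroadcastThreshold > 0 && leftBytes <= autoBroadcastThreshold) Broadcast(BuildLeft)
        else if (rightBytes <= leftBytes) Shuffled(BuildRight)
        else Shuffled(BuildLeft)

      println(choose(BigInt(10) << 30, BigInt(5) << 20, 10L << 20))   // Broadcast(BuildRight)
      println(choose(BigInt(10) << 30, BigInt(2) << 30, 10L << 20))   // Shuffled(BuildRight)
    }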
*/ object HashJoin extends Strategy with PredicateHelper { - private[this] def broadcastHashJoin( + + private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: LogicalPlan, @@ -57,102 +67,102 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition: Option[Expression], side: BuildSide) = { val broadcastHashJoin = execution.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } - def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys( - Inner, - leftKeys, - rightKeys, - condition, - left, - right @ PhysicalOperation(_, _, b: BaseRelation)) - if broadcastTables.contains(b.tableName) => - broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) - case ExtractEquiJoinKeys( - Inner, - leftKeys, - rightKeys, - condition, - left @ PhysicalOperation(_, _, b: BaseRelation), - right) - if broadcastTables.contains(b.tableName) => - broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + BuildRight + } else { + BuildLeft + } val hashJoin = execution.ShuffledHashJoin( - leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + execution.HashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case _ => Nil } } - object PartialAggregation extends Strategy { + object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. 
However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + // Aggregations that can be performed in two phases, before and after the shuffle. - // Construct two phased aggregation. - execution.Aggregate( + // Cases where all aggregates can be codegened. + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) + if canBeCodeGened( + allAggregates(partialComputation) ++ + allAggregates(rewrittenAggregateExpressions)) && + codegenEnabled => + execution.GeneratedAggregate( partial = false, - namedGroupingExpressions.values.map(_.toAttribute).toSeq, + namedGroupingAttributes, rewrittenAggregateExpressions, - execution.Aggregate( + execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, - planLater(child))(sqlContext))(sqlContext) :: Nil - } else { - Nil - } + planLater(child))) :: Nil + + // Cases where some aggregate can not be codegened + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) => + execution.Aggregate( + partial = false, + namedGroupingAttributes, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + groupingExpressions, + partialComputation, + planLater(child))) :: Nil + case _ => Nil } + + def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + case _: Sum | _: Count => false + case _ => true + } + + def allAggregates(exprs: Seq[Expression]) = + exprs.flatMap(_.collect { case a: AggregateExpression => a }) } object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } } @@ -171,16 +181,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - def convertToCatalyst(a: Any): Any = a match { - case s: Seq[Any] => s.map(convertToCatalyst) - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TakeOrdered(limit, order, 
planLater(child))(sqlContext) :: Nil + execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } } @@ -190,11 +194,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // TODO: need to support writing to other types of files. Unify the below code paths. case logical.WriteToFile(path, child) => val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { @@ -223,7 +227,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil + ParquetTableScan(_, relation, filters)) :: Nil case _ => Nil } @@ -261,20 +265,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => - val dataAsRdd = - sparkContext.parallelize(data.map(r => - new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) - execution.ExistingRdd(output, dataAsRdd) :: Nil + ExistingRdd( + output, + ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child))(sqlContext) :: Nil + execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => - execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil - case logical.Except(left,right) => - execution.Except(planLater(left),planLater(right)) :: Nil + execution.Union(unionChildren.map(planLater)) :: Nil + case logical.Except(left, right) => + execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => @@ -283,7 +286,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil + case e @ EvaluatePython(udf, child) => + BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil + case SparkLogicalPlan(existingPlan) => existingPlan :: Nil case 
_ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 966d8f95fc83c..0027f3cf1fc79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output = projectList.map(_.toAttribute) - override def execute() = child.execute().mapPartitions { iter => - @transient val reusableProjection = new MutableProjection(projectList) - iter.map(reusableProjection) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) + + def execute() = child.execute().mapPartitions { iter => + val resuableProjection = buildProjection() + iter.map(resuableProjection) } } @@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output = child.output - override def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.eval(_).asInstanceOf[Boolean]) + @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + + def execute() = child.execute().mapPartitions { iter => + iter.filter(conditionEvaluator) } } @@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: * :: DeveloperApi :: */ @DeveloperApi -case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan { +case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output = children.head.output - override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) - - override def otherCopyArgs = sqlContext :: Nil + override def execute() = sparkContext.union(children.map(_.execute())) } /** @@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex * repartition all the data to a single partition to compute the global limit. */ @DeveloperApi -case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) +case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again - override def otherCopyArgs = sqlContext :: Nil - override def output = child.output /** @@ -148,7 +148,7 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext iter.take(limit).map(row => mutablePair.update(false, row)) } val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) + val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.mapPartitions(_.take(limit).map(_._2)) } @@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. 
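Project and Filter above now build their (possibly code-generated) projection or predicate once per partition and reuse it for every row, which is why buildProjection is a transient lazy builder invoked inside mapPartitions. The plain-Scala sketch below shows that per-partition reuse, including the copy-before-collect caveat that comes with a mutable output buffer; all names are illustrative.

    object PerPartitionProjectionSketch extends App {
      type Row = Array[Any]

      // The builder is created on the driver; invoking it once per partition yields a
      // projection whose output buffer is reused for every row of that partition.
      def buildProjection(indices: Seq[Int]): () => Row => Row =
        () => {
          val out = new Array[Any](indices.length)
          row => {
            var i = 0
            while (i < indices.length) { out(i) = row(indices(i)); i += 1 }
            out
          }
        }

      val partitions: Seq[Iterator[Row]] =
        Seq(Iterator(Array[Any]("a", 1), Array[Any]("b", 2)), Iterator(Array[Any]("c", 3)))

      val projected = partitions.map { iter =>
        val reusableProjection = buildProjection(Seq(1, 0))()   // built once per partition
        iter.map(r => reusableProjection(r).toList).toList      // copy before collecting
      }
      println(projected)   // List(List(List(1, a), List(2, b)), List(List(3, c)))
    }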
*/ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sqlContext: SQLContext) extends UnaryNode { - override def otherCopyArgs = sqlContext :: Nil +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output = child.output - @transient - lazy val ordering = new RowOrdering(sortOrder) + val ordering = new RowOrdering(sortOrder, child.output) + // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(executeCollect(), 1) } /** @@ -189,15 +187,13 @@ case class Sort( override def requiredChildDistribution = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - @transient - lazy val ordering = new RowOrdering(sortOrder) override def execute() = attachTree(this, "sort") { - // TODO: Optimize sorting operation? child.execute() - .mapPartitions( - iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, - preservesPartitioning = true) + .mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) } override def output = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 98d2f89c8ae71..38f37564f1788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { /** @@ -44,28 +45,53 @@ trait Command { case class SetCommand( key: Option[String], value: Option[String], output: Seq[Attribute])( @transient context: SQLContext) - extends LeafNode with Command { + extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => - context.set(k, v) - Array(k -> v) + if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") + context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) + Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + } else { + context.setConf(k, v) + Array(s"$k=$v") + } // Query the value bound to key k. case (Some(k), _) => - Array(k -> context.getOption(k).getOrElse("")) + // TODO (lian) This is just a workaround to make the Simba ODBC driver work. + // Should remove this once we get the ODBC driver updated. 
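The SetCommand change above rewrites the deprecated mapred.reduce.tasks key to the shuffle-partitions setting and returns single "key=value" strings instead of (key, value) pairs. Below is a minimal sketch of that normalization; the literal key strings are the conventional values behind SQLConf.Deprecated.MAPRED_REDUCE_TASKS and SQLConf.SHUFFLE_PARTITIONS and are assumed here, and spark.sql.codegen is just an example key.

    object SetCommandSketch extends App {
      val MapredReduceTasks = "mapred.reduce.tasks"            // deprecated alias (assumed literal)
      val ShufflePartitions = "spark.sql.shuffle.partitions"   // replacement key (assumed literal)

      // Rewrite the deprecated key, store the setting, and emit a single "key=value" row.
      def set(conf: scala.collection.mutable.Map[String, String],
              key: String, value: String): Seq[String] = {
        val effectiveKey = if (key == MapredReduceTasks) ShufflePartitions else key
        conf(effectiveKey) = value
        Seq(s"$effectiveKey=$value")
      }

      val conf = scala.collection.mutable.Map[String, String]()
      println(set(conf, "mapred.reduce.tasks", "10"))      // List(spark.sql.shuffle.partitions=10)
      println(set(conf, "spark.sql.codegen", "true"))      // List(spark.sql.codegen=true)
    }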
+ if (k == "-v") { + val hiveJars = Seq( + "hive-exec-0.12.0.jar", + "hive-service-0.12.0.jar", + "hive-common-0.12.0.jar", + "hive-hwi-0.12.0.jar", + "hive-0.12.0.jar").mkString(":") + + Array( + "system:java.class.path=" + hiveJars, + "system:sun.java.command=shark.SharkServer2") + } + else { + Array(s"$k=${context.getConf(k, "")}") + } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - context.getAll + context.getAllConfs.map { case (k, v) => + s"$k=$v" + }.toSeq case _ => throw new IllegalArgumentException() } def execute(): RDD[Row] = { - val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } context.sparkContext.parallelize(rows, 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c6fbd6d2f6930..f31df051824d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -41,13 +41,13 @@ package object debug { */ @DeveloperApi implicit class DebugQuery(query: SchemaRDD) { - def debug(implicit sc: SparkContext): Unit = { + def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[Long]() val debugPlan = plan transform { case s: SparkPlan if !visited.contains(s.id) => visited += s.id - DebugNode(sc, s) + DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { @@ -57,9 +57,7 @@ package object debug { } } - private[sql] case class DebugNode( - @transient sparkContext: SparkContext, - child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { def references = Set.empty def output = child.output @@ -107,7 +105,9 @@ package object debug { var i = 0 while (i < numColumns) { val value = currentRow(i) - columnStats(i).elementTypes += HashSet(value.getClass.getName) + if (value != null) { + columnStats(i).elementTypes += HashSet(value.getClass.getName) + } i += 1 } currentRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 7d1f11caae838..51bb61530744c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide case object BuildRight extends BuildSide trait HashJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] val buildSide: BuildSide @@ -56,9 +58,9 @@ trait HashJoin { def output = left.output ++ right.output - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + newMutableProjection(streamedKeys, streamedPlan.output) def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. 
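joinIterators in the HashJoin trait above builds a multimap from the build side, skipping rows whose join key contains a null, and then probes it while streaming the other side. The self-contained sketch below follows that flow; Row, hashJoin, and the sample data are illustrative, and the per-row key projection is reduced to a tuple field.

    import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}

    object HashJoinSketch extends App {
      type Row = (Option[String], Int)                      // (join key, payload)

      // Rows with a null join key never enter the hash table and never match,
      // just like the `if (!rowKey.anyNull)` guard above.
      def hashJoin(build: Iterator[Row], stream: Iterator[Row]): Iterator[(Row, Row)] = {
        val table = MutableMap[String, ArrayBuffer[Row]]()
        build.foreach {
          case r @ (Some(k), _) => table.getOrElseUpdate(k, ArrayBuffer()) += r
          case _                => ()                       // null join key: skip the row
        }
        stream.flatMap {
          case s @ (Some(k), _) => table.getOrElse(k, ArrayBuffer()).iterator.map(b => (b, s))
          case _                => Iterator.empty
        }
      }

      val buildSide  = Iterator((Some("a"), 1), (Some("b"), 2), (None, 3))
      val streamSide = Iterator((Some("a"), 10), (Some("c"), 20), (None, 30))
      hashJoin(buildSide, streamSide).foreach(println)      // ((Some(a),1),(Some(a),10))
    }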
@@ -70,7 +72,7 @@ trait HashJoin { while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { val newMatchList = new ArrayBuffer[Row]() @@ -134,6 +136,190 @@ trait HashJoin { } } +/** + * Constant Value for Binary Join Node + */ +object HashOuterJoin { + val DUMMY_LIST = Seq[Row](null) + val EMPTY_LIST = Seq[Row]() +} + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def output = left.output ++ right.output + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. 
+ joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { + // TODO: Use Spark's HashMap implementation. + val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + existingMatchList += currentRow.copy() + } + + hashTable.toMap[Row, ArrayBuffer[Row]] + } + + def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) 
+ // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + } +} + /** * :: DeveloperApi :: * Performs an inner hash join of two child relations by first shuffling the data using the join @@ -187,7 +373,7 @@ case class LeftSemiJoinHash( while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { hashSet.add(rowKey) @@ -217,18 +403,16 @@ case class BroadcastHashJoin( rightKeys: Seq[Expression], buildSide: BuildSide, left: SparkPlan, - right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { - - override def otherCopyArgs = sqlContext :: Nil + right: SparkPlan) extends BinaryNode with HashJoin { - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient lazy val broadcastFuture = future { - sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + sparkContext.broadcast(buildPlan.executeCollect()) } def execute() = { @@ -248,14 +432,11 @@ case class BroadcastHashJoin( @DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. 
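  // Broadcast nested-loop variant of the left semi join: the `broadcast` plan is collected to
  // the driver and shipped to every task, and each `streamed` row is emitted at most once,
  // namely when at least one broadcast row satisfies `condition`. As with BroadcastHashJoin
  // above, the operator now reaches the SparkContext through the plan itself instead of a
  // constructor-injected SQLContext, which is why otherCopyArgs could be dropped.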
override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - def output = left.output /** The Streamed Relation */ @@ -271,7 +452,7 @@ case class LeftSemiJoinBNL( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow @@ -300,8 +481,14 @@ case class LeftSemiJoinBNL( case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { def output = left.output ++ right.output - def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { - case (l: Row, r: Row) => buildRow(l ++ r) + def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } } } @@ -310,14 +497,20 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod */ @DeveloperApi case class BroadcastNestedLoopJoin( - streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - (@transient sqlContext: SQLContext) - extends BinaryNode { + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. - override def outputPartitioning: Partitioning = streamed.outputPartitioning + /** BuildRight means the right relation <=> the broadcast relation. */ + val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } - override def otherCopyArgs = sqlContext :: Nil + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output = { joinType match { @@ -332,11 +525,6 @@ case class BroadcastNestedLoopJoin( } } - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - @transient lazy val boundCondition = InterpretedPredicate( condition @@ -345,58 +533,80 @@ case class BroadcastNestedLoopJoin( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] // TODO: Use Spark's BitSet. - val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 - var matched = false + var streamRowMatched = false while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. 
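          // Plain nested loop: every streamed row is compared with every broadcast row.
          // `buildSide` only determines which operand fills the left/right slot of the
          // JoinedRow (BuildRight <=> the right child is the broadcast relation) and which
          // null row is used for outer padding below; the matching itself is symmetric.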
val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += buildRow(streamedRow ++ broadcastedRow) - matched = true - includedBroadcastTuples += i + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => } i += 1 } - if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => } } Iterator((matchedRows, includedBroadcastTuples)) } - val includedBroadcastTuples = streamedPlusMatches.map(_._2) + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = if (includedBroadcastTuples.count == 0) { new scala.collection.mutable.BitSet(broadcastedRelation.value.size) } else { - streamedPlusMatches.map(_._2).reduce(_ ++ _) + includedBroadcastTuples.reduce(_ ++ _) } - val rightOuterMatches: Seq[Row] = - if (joinType == RightOuter || joinType == FullOuter) { - broadcastedRelation.value.zipWithIndex.filter { - case (row, i) => !allIncludedBroadcastTuples.contains(i) - }.map { - // TODO: Use projection. - case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case _ => + } } - } else { - Vector() + i += 1 } + arrBuf.toSeq + } // TODO: Breaks lineage. - sqlContext.sparkContext.union( - streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches)) + sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala new file mode 100644 index 0000000000000..b92091b560b1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -0,0 +1,177 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution + +import java.util.{List => JList, Map => JMap} + +import net.razorvine.pickle.{Pickler, Unpickler} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.{Accumulator, Logging => SparkLogging} + +import scala.collection.JavaConversions._ + +/** + * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. + */ +private[spark] case class PythonUDF( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with SparkLogging { + + override def toString = s"PythonUDF#$name(${children.mkString(",")})" + + def nullable: Boolean = true + def references: Set[Attribute] = children.flatMap(_.references).toSet + + override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") +} + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan) = plan transform { + // Skip EvaluatePython nodes. + case p: EvaluatePython => p + + case l: LogicalPlan => + // Extract any PythonUDFs from the current operator. + val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf}) + if (udfs.isEmpty) { + // If there aren't any, we are done. + l + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + val udf = udfs.head + + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = l.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartisian product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. 
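          // The rewrite therefore has three parts: the selected child is wrapped in an
          // EvaluatePython node that appends the UDF result as an extra column, every
          // occurrence of this UDF in the operator is replaced by evaluation.resultAttribute
          // (matched by expression id), and the enclosing Project restores the operator's
          // original output so the appended pythonUDF column does not leak to parents.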
+ logical.Project( + l.output, + l.transformExpressions { + case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute + }.withNewChildren(newChildren)) + } + } +} + +/** + * :: DeveloperApi :: + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. + */ +@DeveloperApi +case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { + val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() + + def references = Set.empty + def output = child.output :+ resultAttribute +} + +/** + * :: DeveloperApi :: + * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. The input + * data is cached and zipped with the result of the udf evaluation. + */ +@DeveloperApi +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + def children = child :: Nil + + def execute() = { + // TODO: Clean up after ourselves? + val childResults = child.execute().map(_.copy()).cache() + + val parent = childResults.mapPartitions { iter => + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + iter.grouped(1000).map { inputRows => + val toBePickled = inputRows.map(currentRow(_).toArray).toArray + pickle.dumps(toBePickled) + } + } + + val pyRDD = new PythonRDD( + parent, + udf.command, + udf.envVars, + udf.pythonIncludes, + false, + udf.pythonExec, + Seq[Broadcast[Array[Byte]]](), + udf.accumulator + ).mapPartitions { iter => + val pickle = new Unpickler + iter.flatMap { pickedResult => + val unpickledBatch = pickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] + } + }.mapPartitions { iter => + val row = new GenericMutableRow(1) + iter.map { result => + row(0) = udf.dataType match { + case StringType => result.toString + case other => result + } + row: Row + } + } + + childResults.zip(pyRDD).mapPartitions { iter => + val joinedRow = new JoinedRow() + iter.map { + case (row, udfResult) => + joinedRow(row, udfResult) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index b48c70ee73a27..a3d2a1c7a51f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.json +import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal @@ -25,30 +26,25 @@ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} -import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { + private[sql] def jsonStringToRow( + json: RDD[String], + schema: StructType): RDD[Row] = { + parseJson(json).map(parsed => asRow(parsed, schema)) + } + private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0): LogicalPlan = { + samplingRatio: Double = 1.0): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 
0.99) json else json.sample(false, samplingRatio, 1) val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) - val baseSchema = createSchema(allKeys) - - createLogicalPlan(json, baseSchema) - } - - private def createLogicalPlan( - json: RDD[String], - baseSchema: StructType): LogicalPlan = { - val schema = nullTypeToStringType(baseSchema) - - SparkLogicalPlan(ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema)))) + createSchema(allKeys) } private def createSchema(allKeys: Set[(String, DataType)]): StructType = { @@ -72,8 +68,8 @@ private[sql] object JsonRDD extends Logging { val (topLevel, structLike) = values.partition(_.size == 1) val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { - case ArrayType(StructType(Nil)) => false - case ArrayType(_) => true + case ArrayType(StructType(Nil), _) => false + case ArrayType(_, _) => true case struct: StructType => false case _ => true } @@ -87,7 +83,8 @@ private[sql] object JsonRDD extends Logging { val structType = makeStruct(nestedFields, prefix :+ name) val dataType = resolved.get(prefix :+ name).get dataType match { - case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true)) + case array: ArrayType => + Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true)) case struct: StructType => Some(StructField(name, structType, nullable = true)) // dataType is StringType means that we have resolved type conflicts involving // primitive types and complex types. So, the type of name has been relaxed to @@ -106,6 +103,22 @@ private[sql] object JsonRDD extends Logging { makeStruct(resolved.keySet.toSeq, Nil) } + private[sql] def nullTypeToStringType(struct: StructType): StructType = { + val fields = struct.fields.map { + case StructField(fieldName, dataType, nullable) => { + val newType = dataType match { + case NullType => StringType + case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) + case struct: StructType => nullTypeToStringType(struct) + case other: DataType => other + } + StructField(fieldName, newType, nullable) + } + } + + StructType(fields) + } + /** * Returns the most general data type for two given data types. */ @@ -136,8 +149,8 @@ private[sql] object JsonRDD extends Logging { case StructField(name, _, _) => name }) } - case (ArrayType(elementType1), ArrayType(elementType2)) => - ArrayType(compatibleType(elementType1, elementType2)) + case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // TODO: We should use JsonObjectStringType to mark that values of field will be // strings and every string is a Json object. case (_, _) => StringType @@ -145,18 +158,13 @@ private[sql] object JsonRDD extends Logging { } } - private def typeOfPrimitiveValue(value: Any): DataType = { - value match { - case value: java.lang.String => StringType - case value: java.lang.Integer => IntegerType - case value: java.lang.Long => LongType + private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { + ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. case value: java.math.BigInteger => DecimalType - case value: java.lang.Double => DoubleType + // DecimalType's JVMType is scala BigDecimal. 
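      // (java.math.BigDecimal is therefore also special-cased here rather than left to
      // ScalaReflection.typeOfObject. Any value neither partial function recognises falls
      // through to StringType, which is also the fallback compatibleType uses when it has to
      // resolve conflicting types.)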
case value: java.math.BigDecimal => DecimalType - case value: java.lang.Boolean => BooleanType - case null => NullType // Unexpected data type. case _ => StringType } @@ -169,12 +177,13 @@ private[sql] object JsonRDD extends Logging { * treat the element as String. */ private def typeOfArray(l: Seq[Any]): ArrayType = { + val containsNull = l.exists(v => v == null) val elements = l.flatMap(v => Option(v)) if (elements.isEmpty) { // If this JSON array is empty, we use NullType as a placeholder. // If this array is not empty in other JSON objects, we can resolve // the type after we have passed through all JSON objects. - ArrayType(NullType) + ArrayType(NullType, containsNull) } else { val elementType = elements.map { e => e match { @@ -186,7 +195,7 @@ private[sql] object JsonRDD extends Logging { } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - ArrayType(elementType) + ArrayType(elementType, containsNull) } } @@ -213,15 +222,16 @@ private[sql] object JsonRDD extends Logging { case (key: String, array: Seq[_]) => { // The value associated with the key is an array. typeOfArray(array) match { - case ArrayType(StructType(Nil)) => { + case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. array.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { case (k, dataType) => (s"$key.$k", dataType) - } :+ (key, ArrayType(StructType(Nil))) + } :+ (key, ArrayType(StructType(Nil), containsNull)) } - case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil + case ArrayType(elementType, containsNull) => + (key, ArrayType(elementType, containsNull)) :: Nil } } case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil @@ -259,8 +269,11 @@ private[sql] object JsonRDD extends Logging { // the ObjectMapper will take the last value associated with this duplicate key. // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() - iter.map(record => mapper.readValue(record, classOf[java.util.Map[String, Any]])) - }).map(scalafy).map(_.asInstanceOf[Map[String, Any]]) + iter.map { record => + val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) + parsed.asInstanceOf[Map[String, Any]] + } + }) } private def toLong(value: Any): Long = { @@ -331,7 +344,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case ArrayType(elementType) => + case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case StringType => toString(value) case IntegerType => value.asInstanceOf[IntegerType.JvmType] @@ -345,6 +358,7 @@ private[sql] object JsonRDD extends Logging { } private def asRow(json: Map[String,Any], schema: StructType): Row = { + // TODO: Reuse the row instead of creating a new one for every record. 
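    // The conversion is schema-driven rather than data-driven: fields missing from a record
    // become nulls, keys that are not part of the schema are dropped, and nested StructType /
    // ArrayType(StructType) values are converted recursively through asRow. This is what lets
    // jsonStringToRow above apply a caller-supplied schema to raw JSON strings.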
val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { // StructType @@ -353,7 +367,7 @@ private[sql] object JsonRDD extends Logging { v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) // ArrayType(StructType) - case (StructField(name, ArrayType(structType: StructType), _), i) => + case (StructField(name, ArrayType(structType: StructType, _), _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( v => v.asInstanceOf[Seq[Any]].map( @@ -367,32 +381,4 @@ private[sql] object JsonRDD extends Logging { row } - - private def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType) => ArrayType(StringType) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - private def asAttributes(struct: StructType): Seq[AttributeReference] = { - struct.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) - } - - private def asStruct(attributes: Seq[AttributeReference]): StructType = { - val fields = attributes.map { - case AttributeReference(name, dataType, nullable) => StructField(name, dataType, nullable) - } - - StructType(fields) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala new file mode 100644 index 0000000000000..f513eae9c2d13 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.DeveloperApi + +/** + * Allows the execution of relational queries, including those expressed in SQL using Spark. + * + * @groupname dataType Data types + * @groupdesc Spark SQL data types. + * @groupprio dataType -3 + * @groupname field Field + * @groupprio field -2 + * @groupname row Row + * @groupprio row -1 + */ +package object sql { + + /** + * :: DeveloperApi :: + * + * Represents one row of output from a relational operator. + * @group row + */ + @DeveloperApi + type Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * A [[Row]] object can be constructed by providing field values. Example: + * {{{ + * import org.apache.spark.sql._ + * + * // Create a Row from values. + * Row(value1, value2, value3, ...) + * // Create a Row from a Seq of values. 
+ * Row.fromSeq(Seq(value1, value2, ...)) + * }}} + * + * A value of a row can be accessed through both generic access by ordinal, + * which will incur boxing overhead for primitives, as well as native primitive access. + * An example of generic access by ordinal: + * {{{ + * import org.apache.spark.sql._ + * + * val row = Row(1, true, "a string", null) + * // row: Row = [1,true,a string,null] + * val firstValue = row(0) + * // firstValue: Any = 1 + * val fourthValue = row(3) + * // fourthValue: Any = null + * }}} + * + * For native primitive access, it is invalid to use the native primitive interface to retrieve + * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a + * value that might be null. + * An example of native primitive access: + * {{{ + * // using the row from the previous example. + * val firstValue = row.getInt(0) + * // firstValue: Int = 1 + * val isNull = row.isNullAt(3) + * // isNull: Boolean = true + * }}} + * + * Interfaces related to native primitive access are: + * + * `isNullAt(i: Int): Boolean` + * + * `getInt(i: Int): Int` + * + * `getLong(i: Int): Long` + * + * `getDouble(i: Int): Double` + * + * `getFloat(i: Int): Float` + * + * `getBoolean(i: Int): Boolean` + * + * `getShort(i: Int): Short` + * + * `getByte(i: Int): Byte` + * + * `getString(i: Int): String` + * + * Fields in a [[Row]] object can be extracted in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + * + * @group row + */ + @DeveloperApi + val Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * The base type of all Spark SQL data types. + * + * @group dataType + */ + @DeveloperApi + type DataType = catalyst.types.DataType + + /** + * :: DeveloperApi :: + * + * The data type representing `String` values + * + * @group dataType + */ + @DeveloperApi + val StringType = catalyst.types.StringType + + /** + * :: DeveloperApi :: + * + * The data type representing `Array[Byte]` values. + * + * @group dataType + */ + @DeveloperApi + val BinaryType = catalyst.types.BinaryType + + /** + * :: DeveloperApi :: + * + * The data type representing `Boolean` values. + * + *@group dataType + */ + @DeveloperApi + val BooleanType = catalyst.types.BooleanType + + /** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Timestamp` values. + * + * @group dataType + */ + @DeveloperApi + val TimestampType = catalyst.types.TimestampType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * + * @group dataType + */ + @DeveloperApi + val DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `Double` values. + * + * @group dataType + */ + @DeveloperApi + val DoubleType = catalyst.types.DoubleType + + /** + * :: DeveloperApi :: + * + * The data type representing `Float` values. + * + * @group dataType + */ + @DeveloperApi + val FloatType = catalyst.types.FloatType + + /** + * :: DeveloperApi :: + * + * The data type representing `Byte` values. + * + * @group dataType + */ + @DeveloperApi + val ByteType = catalyst.types.ByteType + + /** + * :: DeveloperApi :: + * + * The data type representing `Int` values. + * + * @group dataType + */ + @DeveloperApi + val IntegerType = catalyst.types.IntegerType + + /** + * :: DeveloperApi :: + * + * The data type representing `Long` values. 
+ * + * @group dataType + */ + @DeveloperApi + val LongType = catalyst.types.LongType + + /** + * :: DeveloperApi :: + * + * The data type representing `Short` values. + * + * @group dataType + */ + @DeveloperApi + val ShortType = catalyst.types.ShortType + + /** + * :: DeveloperApi :: + * + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * + * @group dataType + */ + @DeveloperApi + type ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * An [[ArrayType]] object can be constructed with two ways, + * {{{ + * ArrayType(elementType: DataType, containsNull: Boolean) + * }}} and + * {{{ + * ArrayType(elementType: DataType) + * }}} + * For `ArrayType(elementType)`, the field of `containsNull` is set to `false`. + * + * @group dataType + */ + @DeveloperApi + val ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * The data type representing `Map`s. A [[MapType]] object comprises three fields, + * `keyType: [[DataType]]`, `valueType: [[DataType]]` and `valueContainsNull: Boolean`. + * The field of `keyType` is used to specify the type of keys in the map. + * The field of `valueType` is used to specify the type of values in the map. + * The field of `valueContainsNull` is used to specify if values of this map has `null` values. + * For values of a MapType column, keys are not allowed to have `null` values. + * + * @group dataType + */ + @DeveloperApi + type MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * A [[MapType]] object can be constructed with two ways, + * {{{ + * MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) + * }}} and + * {{{ + * MapType(keyType: DataType, valueType: DataType) + * }}} + * For `MapType(keyType: DataType, valueType: DataType)`, + * the field of `valueContainsNull` is set to `true`. + * + * @group dataType + */ + @DeveloperApi + val MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * The data type representing [[Row]]s. + * A [[StructType]] object comprises a [[Seq]] of [[StructField]]s. + * + * @group dataType + */ + @DeveloperApi + type StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. 
Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Those names do not have matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). + * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object represents a field in a [[StructType]] object. + * A [[StructField]] object comprises three fields, `name: [[String]]`, `dataType: [[DataType]]`, + * and `nullable: Boolean`. The field of `name` is the name of a `StructField`. The field of + * `dataType` specifies the data type of a `StructField`. + * The field of `nullable` specifies if values of a `StructField` can contain `null` values. + * + * @group field + */ + @DeveloperApi + type StructField = catalyst.types.StructField + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object can be constructed by + * {{{ + * StructField(name: String, dataType: DataType, nullable: Boolean) + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructField = catalyst.types.StructField +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index de8fe2dae38f6..0a3b59cbc233a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -75,21 +75,21 @@ private[sql] object CatalystConverter { val fieldType: DataType = field.dataType fieldType match { // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType) => { + case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) } // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType) => { + case ArrayType(elementType: DataType, false) => { new CatalystArrayConverter(elementType, fieldIndex, parent) } case StructType(fields: Seq[StructField]) => { new CatalystStructConverter(fields.toArray, fieldIndex, parent) } - case MapType(keyType: DataType, valueType: DataType) => { + case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { new CatalystMapConverter( Array( new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)), + new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), fieldIndex, parent) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 9c4771d1a9846..b3bae5db0edbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -27,6 +27,7 @@ import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} @@ -45,7 +46,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} */ private[sql] case class ParquetRelation( path: String, - @transient conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { + @transient conf: Option[Configuration], + @transient sqlContext: SQLContext) + extends LeafNode with MultiInstanceRelation { self: Product => @@ -59,7 +62,7 @@ private[sql] case class ParquetRelation( /** Attributes */ override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) - override def newInstance = ParquetRelation(path).asInstanceOf[this.type] + override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, @@ -68,6 +71,9 @@ private[sql] case class ParquetRelation( p.path == path && p.output == output case _ => false } + + // TODO: Use data from the footers. 
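  // Until footer metadata is wired in, every Parquet relation reports the same configured
  // sqlContext.defaultSizeInBytes. That default is presumably large enough that a relation of
  // unknown size will not be picked for a broadcast join by size-based planner strategies.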
+ override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes) } private[sql] object ParquetRelation { @@ -104,13 +110,14 @@ private[sql] object ParquetRelation { */ def create(pathString: String, child: LogicalPlan, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { if (!child.resolved) { throw new UnresolvedException[LogicalPlan]( child, "Attempt to create Parquet table from unresolved child (when schema is not available)") } - createEmpty(pathString, child.output, false, conf) + createEmpty(pathString, child.output, false, conf, sqlContext) } /** @@ -125,14 +132,15 @@ private[sql] object ParquetRelation { def createEmpty(pathString: String, attributes: Seq[Attribute], allowExisting: Boolean, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { val path = checkPath(pathString, allowExisting, conf) if (conf.get(ParquetOutputFormat.COMPRESSION) == null) { conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name()) } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString, Some(conf)) { + new ParquetRelation(path.toString, Some(conf), sqlContext) { override val output = attributes } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ea74320d06c86..759a2a586b926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -51,14 +51,20 @@ import org.apache.spark.{Logging, SerializableWritable, TaskContext} * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. */ case class ParquetTableScan( - // note: output cannot be transient, see - // https://issues.apache.org/jira/browse/SPARK-1367 - output: Seq[Attribute], + attributes: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Seq[Expression])( - @transient val sqlContext: SQLContext) + columnPruningPred: Seq[Expression]) extends LeafNode { + // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes + // by exprId. note: output cannot be transient, see + // https://issues.apache.org/jira/browse/SPARK-1367 + val output = attributes.map { a => + relation.output + .find(o => o.exprId == a.exprId) + .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relation.output.mkString(",")}")) + } + override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -99,8 +105,6 @@ case class ParquetTableScan( .filter(_ != null) // Parquet's record filters may produce null values } - override def otherCopyArgs = sqlContext :: Nil - /** * Applies a (candidate) projection. 
* @@ -110,10 +114,9 @@ case class ParquetTableScan( def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { val success = validateProjection(prunedAttributes) if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext) + ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") - this } } @@ -150,8 +153,7 @@ case class ParquetTableScan( case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, - overwrite: Boolean = false)( - @transient val sqlContext: SQLContext) + overwrite: Boolean = false) extends UnaryNode with SparkHadoopMapReduceUtil { /** @@ -171,7 +173,7 @@ case class InsertIntoParquetTable( val writeSupport = if (child.output.map(_.dataType).forall(_.isPrimitive)) { - logger.debug("Initializing MutableRowWriteSupport") + log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { classOf[org.apache.spark.sql.parquet.RowWriteSupport] @@ -203,8 +205,6 @@ case class InsertIntoParquetTable( override def output = child.output - override def otherCopyArgs = sqlContext :: Nil - /** * Stores the given Row RDD as a Hadoop file. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 39294a3f4bf5a..6d4ce32ac5bfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -172,10 +172,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - case t @ ArrayType(_) => writeArray( + case t @ ArrayType(_, false) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) - case t @ MapType(_, _) => writeMap( + case t @ MapType(_, _, _) => writeMap( t, value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) case t @ StructType(_) => writeStruct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index d4599da711254..837ea7695dbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.Job +import org.apache.spark.sql.test.TestSQLContext import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup @@ -103,7 +104,7 @@ private[sql] object ParquetTestData { val testDir = Utils.createTempDir() val testFilterDir = Utils.createTempDir() - lazy val testData = new ParquetRelation(testDir.toURI.toString) + lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext) val testNestedSchema1 = // based on blogpost example, source: @@ -202,8 +203,10 @@ private[sql] object ParquetTestData { val testNestedDir3 = Utils.createTempDir() val testNestedDir4 = Utils.createTempDir() - lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) - lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + lazy 
val testNestedData1 = + new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext) + lazy val testNestedData2 = + new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext) def writeFile() = { testDir.delete() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 58370b955a5ec..aaef1a1d474fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -116,7 +116,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - new ArrayType(toDataType(field)) + ArrayType(toDataType(field), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -130,7 +130,9 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. + MapType(keyType, valueType) } case _ => { // Note: the order of these checks is important! @@ -140,10 +142,12 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. + MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType val elementType = toDataType(groupType.getFields.apply(0)) - new ArrayType(elementType) + ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields @@ -151,7 +155,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ptype.getName, toDataType(ptype), ptype.getRepetition != Repetition.REQUIRED)) - new StructType(fields) + StructType(fields) } } } @@ -234,7 +238,7 @@ private[parquet] object ParquetTypesConverter extends Logging { new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) }.getOrElse { ctype match { - case ArrayType(elementType) => { + case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, @@ -248,7 +252,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } new ParquetGroupType(repetition, name, fields) } - case MapType(keyType, valueType) => { + case MapType(keyType, valueType, _) => { val parquetKeyType = fromDataType( keyType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala new file mode 100644 index 0000000000000..77353f4eb0227 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types.util + +import org.apache.spark.sql._ +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField} + +import scala.collection.JavaConverters._ + +protected[sql] object DataTypeConversions { + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asJavaStructField(scalaStructField: StructField): JStructField = { + JDataType.createStructField( + scalaStructField.name, + asJavaDataType(scalaStructField.dataType), + scalaStructField.nullable) + } + + /** + * Returns the equivalent DataType in Java for the given DataType in Scala. + */ + def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case StringType => JDataType.StringType + case BinaryType => JDataType.BinaryType + case BooleanType => JDataType.BooleanType + case TimestampType => JDataType.TimestampType + case DecimalType => JDataType.DecimalType + case DoubleType => JDataType.DoubleType + case FloatType => JDataType.FloatType + case ByteType => JDataType.ByteType + case IntegerType => JDataType.IntegerType + case LongType => JDataType.LongType + case ShortType => JDataType.ShortType + + case arrayType: ArrayType => JDataType.createArrayType( + asJavaDataType(arrayType.elementType), arrayType.containsNull) + case mapType: MapType => JDataType.createMapType( + asJavaDataType(mapType.keyType), + asJavaDataType(mapType.valueType), + mapType.valueContainsNull) + case structType: StructType => JDataType.createStructType( + structType.fields.map(asJavaStructField).asJava) + } + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asScalaStructField(javaStructField: JStructField): StructField = { + StructField( + javaStructField.getName, + asScalaDataType(javaStructField.getDataType), + javaStructField.isNullable) + } + + /** + * Returns the equivalent DataType in Scala for the given DataType in Java. 
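   *
   * Illustrative usage (types taken from the conversions below):
   * {{{
   *   // A Java ArrayType of nullable integers maps back to the Scala ArrayType.
   *   asScalaDataType(JDataType.createArrayType(JDataType.IntegerType, true))
   *   // == ArrayType(IntegerType, containsNull = true)
   * }}}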
+ */ + def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case stringType: org.apache.spark.sql.api.java.StringType => + StringType + case binaryType: org.apache.spark.sql.api.java.BinaryType => + BinaryType + case booleanType: org.apache.spark.sql.api.java.BooleanType => + BooleanType + case timestampType: org.apache.spark.sql.api.java.TimestampType => + TimestampType + case decimalType: org.apache.spark.sql.api.java.DecimalType => + DecimalType + case doubleType: org.apache.spark.sql.api.java.DoubleType => + DoubleType + case floatType: org.apache.spark.sql.api.java.FloatType => + FloatType + case byteType: org.apache.spark.sql.api.java.ByteType => + ByteType + case integerType: org.apache.spark.sql.api.java.IntegerType => + IntegerType + case longType: org.apache.spark.sql.api.java.LongType => + LongType + case shortType: org.apache.spark.sql.api.java.ShortType => + ShortType + + case arrayType: org.apache.spark.sql.api.java.ArrayType => + ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) + case mapType: org.apache.spark.sql.api.java.MapType => + MapType( + asScalaDataType(mapType.getKeyType), + asScalaDataType(mapType.getValueType), + mapType.isValueContainsNull) + case structType: org.apache.spark.sql.api.java.StructType => + StructType(structType.getFields.map(asScalaStructField)) + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java new file mode 100644 index 0000000000000..a9a11285def54 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +import org.apache.spark.sql.api.java.UDF1; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Suite; +import org.junit.runner.RunWith; + +import org.apache.spark.api.java.JavaSparkContext; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. 
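// Exercises the Java UDF registration API end to end: a UDF1 and a UDF2 are registered under
// the name "stringLengthTest" via registerFunction and then invoked from SQL, and the results
// are read back through the returned Row.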
+public class JavaAPISuite implements Serializable { + private transient JavaSparkContext sc; + private transient JavaSQLContext sqlContext; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + sqlContext = new JavaSQLContext(sc); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF1() { + @Override + public Integer call(String str) throws Exception { + return str.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test')").first(); + assert(result.getInt(0) == 4); + } + + @SuppressWarnings("unchecked") + @Test + public void udf2Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", + // (String str1, String str2) -> str1.length() + str2.length, + // DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF2() { + @Override + public Integer call(String str1, String str2) throws Exception { + return str1.length() + str2.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + assert(result.getInt(0) == 9); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java new file mode 100644 index 0000000000000..33e5020bc636a --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. 
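// Covers applying an explicit schema from Java: Rows built with Row.create get a StructType
// attached through JavaSQLContext.applySchema, and JSON input is checked both with inferred
// and with user-provided schemas via jsonRDD.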
+public class JavaApplySchemaSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite"); + javaSqlCtx = new JavaSQLContext(javaCtx); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + public static class Person implements Serializable { + private String name; + private int age; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + } + + @Test + public void applySchema() { + List<Person> personList = new ArrayList<Person>(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( + new Function<Person, Row>() { + public Row call(Person person) throws Exception { + return Row.create(person.getName(), person.getAge()); + } + }); + + List<StructField> fields = new ArrayList<StructField>(2); + fields.add(DataType.createStructField("name", DataType.StringType, false)); + fields.add(DataType.createStructField("age", DataType.IntegerType, false)); + StructType schema = DataType.createStructType(fields); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); + schemaRDD.registerTempTable("people"); + List<Row> actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + + List<Row> expected = new ArrayList<Row>(2); + expected.add(Row.create("Michael", 29)); + expected.add(Row.create("Yin", 28)); + + Assert.assertEquals(expected, actual); + } + + @Test + public void applySchemaToJSON() { + JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList( + "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + + "\"boolean\":true, \"null\":null}", + "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + + "\"boolean\":false, \"null\":null}")); + List<StructField> fields = new ArrayList<StructField>(7); + fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); + fields.add(DataType.createStructField("double", DataType.DoubleType, true)); + fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); + fields.add(DataType.createStructField("long", DataType.LongType, true)); + fields.add(DataType.createStructField("null", DataType.StringType, true)); + fields.add(DataType.createStructField("string", DataType.StringType, true)); + StructType expectedSchema = DataType.createStructType(fields); + List<Row> expectedResult = new ArrayList<Row>(2); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758069"), + false, + 1.7976931348623157E305, + 11, + 21474836469L, + null, + "this is another simple string.")); + + JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD); + StructType actualSchema1 = schemaRDD1.schema(); + Assert.assertEquals(expectedSchema, actualSchema1); +
schemaRDD1.registerTempTable("jsonTable1"); + List<Row> actual1 = javaSqlCtx.sql("select * from jsonTable1").collect(); + Assert.assertEquals(expectedResult, actual1); + + JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); + StructType actualSchema2 = schemaRDD2.schema(); + Assert.assertEquals(expectedSchema, actualSchema2); + schemaRDD2.registerTempTable("jsonTable2"); + List<Row> actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); + Assert.assertEquals(expectedResult, actual2); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java new file mode 100644 index 0000000000000..52d07b5425cc3 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class JavaRowSuite { + private byte byteValue; + private short shortValue; + private int intValue; + private long longValue; + private float floatValue; + private double doubleValue; + private BigDecimal decimalValue; + private boolean booleanValue; + private String stringValue; + private byte[] binaryValue; + private Timestamp timestampValue; + + @Before + public void setUp() { + byteValue = (byte)127; + shortValue = (short)32767; + intValue = 2147483647; + longValue = 9223372036854775807L; + floatValue = (float)3.4028235E38; + doubleValue = 1.7976931348623157E308; + decimalValue = new BigDecimal("1.7976931348623157E328"); + booleanValue = true; + stringValue = "this is a string"; + binaryValue = stringValue.getBytes(); + timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); + } + + @Test + public void constructSimpleRow() { + Row simpleRow = Row.create( + byteValue, // ByteType + new Byte(byteValue), + shortValue, // ShortType + new Short(shortValue), + intValue, // IntegerType + new Integer(intValue), + longValue, // LongType + new Long(longValue), + floatValue, // FloatType + new Float(floatValue), + doubleValue, // DoubleType + new Double(doubleValue), + decimalValue, // DecimalType + booleanValue, // BooleanType + new Boolean(booleanValue), + stringValue, // StringType + binaryValue, // BinaryType + timestampValue, // TimestampType + null // null + ); + + Assert.assertEquals(byteValue, simpleRow.getByte(0)); + Assert.assertEquals(byteValue, simpleRow.get(0)); + Assert.assertEquals(byteValue, simpleRow.getByte(1)); + Assert.assertEquals(byteValue, simpleRow.get(1)); +
Assert.assertEquals(shortValue, simpleRow.getShort(2)); + Assert.assertEquals(shortValue, simpleRow.get(2)); + Assert.assertEquals(shortValue, simpleRow.getShort(3)); + Assert.assertEquals(shortValue, simpleRow.get(3)); + Assert.assertEquals(intValue, simpleRow.getInt(4)); + Assert.assertEquals(intValue, simpleRow.get(4)); + Assert.assertEquals(intValue, simpleRow.getInt(5)); + Assert.assertEquals(intValue, simpleRow.get(5)); + Assert.assertEquals(longValue, simpleRow.getLong(6)); + Assert.assertEquals(longValue, simpleRow.get(6)); + Assert.assertEquals(longValue, simpleRow.getLong(7)); + Assert.assertEquals(longValue, simpleRow.get(7)); + // When we create the row, we do not do any conversion + // for a float/double value, so we just set the delta to 0. + Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0); + Assert.assertEquals(floatValue, simpleRow.get(8)); + Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0); + Assert.assertEquals(floatValue, simpleRow.get(9)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0); + Assert.assertEquals(doubleValue, simpleRow.get(10)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0); + Assert.assertEquals(doubleValue, simpleRow.get(11)); + Assert.assertEquals(decimalValue, simpleRow.get(12)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(13)); + Assert.assertEquals(booleanValue, simpleRow.get(13)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(14)); + Assert.assertEquals(booleanValue, simpleRow.get(14)); + Assert.assertEquals(stringValue, simpleRow.getString(15)); + Assert.assertEquals(stringValue, simpleRow.get(15)); + Assert.assertEquals(binaryValue, simpleRow.get(16)); + Assert.assertEquals(timestampValue, simpleRow.get(17)); + Assert.assertEquals(true, simpleRow.isNullAt(18)); + Assert.assertEquals(null, simpleRow.get(18)); + } + + @Test + public void constructComplexRow() { + // Simple array + List<String> simpleStringArray = Arrays.asList( + stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); + + // Simple map + Map<String, Long> simpleMap = new HashMap<String, Long>(); + simpleMap.put(stringValue + " (1)", longValue); + simpleMap.put(stringValue + " (2)", longValue - 1); + simpleMap.put(stringValue + " (3)", longValue - 2); + + // Simple struct + Row simpleStruct = Row.create( + doubleValue, stringValue, timestampValue, null); + + // Complex array + List<Map<String, Long>> arrayOfMaps = Arrays.asList(simpleMap); + List<Row> arrayOfRows = Arrays.asList(simpleStruct); + + // Complex map + Map<List<Row>, Row> complexMap = new HashMap<List<Row>, Row>(); + complexMap.put(arrayOfRows, simpleStruct); + + // Complex struct + Row complexStruct = Row.create( + simpleStringArray, + simpleMap, + simpleStruct, + arrayOfMaps, + arrayOfRows, + complexMap, + null); + Assert.assertEquals(simpleStringArray, complexStruct.get(0)); + Assert.assertEquals(simpleMap, complexStruct.get(1)); + Assert.assertEquals(simpleStruct, complexStruct.get(2)); + Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); + Assert.assertEquals(arrayOfRows, complexStruct.get(4)); + Assert.assertEquals(complexMap, complexStruct.get(5)); + Assert.assertEquals(null, complexStruct.get(6)); + + // A very complex row + Row complexRow = Row.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); + Assert.assertEquals(arrayOfMaps, complexRow.get(0)); + Assert.assertEquals(arrayOfRows, complexRow.get(1)); + Assert.assertEquals(complexMap, complexRow.get(2)); + Assert.assertEquals(complexStruct, complexRow.get(3)); + } +} diff --git
a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java new file mode 100644 index 0000000000000..d099a48a1f4b6 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.List; +import java.util.ArrayList; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.sql.types.util.DataTypeConversions; + +public class JavaSideDataTypeConversionSuite { + public void checkDataType(DataType javaDataType) { + org.apache.spark.sql.catalyst.types.DataType scalaDataType = + DataTypeConversions.asScalaDataType(javaDataType); + DataType actual = DataTypeConversions.asJavaDataType(scalaDataType); + Assert.assertEquals(javaDataType, actual); + } + + @Test + public void createDataTypes() { + // Simple DataTypes. + checkDataType(DataType.StringType); + checkDataType(DataType.BinaryType); + checkDataType(DataType.BooleanType); + checkDataType(DataType.TimestampType); + checkDataType(DataType.DecimalType); + checkDataType(DataType.DoubleType); + checkDataType(DataType.FloatType); + checkDataType(DataType.ByteType); + checkDataType(DataType.IntegerType); + checkDataType(DataType.LongType); + checkDataType(DataType.ShortType); + + // Simple ArrayType. + DataType simpleJavaArrayType = DataType.createArrayType(DataType.StringType, true); + checkDataType(simpleJavaArrayType); + + // Simple MapType. + DataType simpleJavaMapType = DataType.createMapType(DataType.StringType, DataType.LongType); + checkDataType(simpleJavaMapType); + + // Simple StructType. + List<StructField> simpleFields = new ArrayList<StructField>(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); + DataType simpleJavaStructType = DataType.createStructType(simpleFields); + checkDataType(simpleJavaStructType); + + // Complex StructType.
+ List<StructField> complexFields = new ArrayList<StructField>(); + complexFields.add(DataType.createStructField("simpleArray", simpleJavaArrayType, true)); + complexFields.add(DataType.createStructField("simpleMap", simpleJavaMapType, true)); + complexFields.add(DataType.createStructField("simpleStruct", simpleJavaStructType, true)); + complexFields.add(DataType.createStructField("boolean", DataType.BooleanType, false)); + DataType complexJavaStructType = DataType.createStructType(complexFields); + checkDataType(complexJavaStructType); + + // Complex ArrayType. + DataType complexJavaArrayType = DataType.createArrayType(complexJavaStructType, true); + checkDataType(complexJavaArrayType); + + // Complex MapType. + DataType complexJavaMapType = + DataType.createMapType(complexJavaStructType, complexJavaArrayType, false); + checkDataType(complexJavaMapType); + } + + @Test + public void illegalArgument() { + // ArrayType + try { + DataType.createArrayType(null, true); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // MapType + try { + DataType.createMapType(null, DataType.StringType); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(DataType.StringType, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(null, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // StructField + try { + DataType.createStructField(null, DataType.StringType, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField("name", null, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField(null, null, true); + } catch (IllegalArgumentException expectedException) { + } + + // StructType + try { + List<StructField> simpleFields = new ArrayList<StructField>(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(null); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + List<StructField> simpleFields = new ArrayList<StructField>(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c3c0dcb1aa00b..fbf9bd9dbcdea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -78,7 +78,7 @@ class CachedTableSuite extends QueryTest { } test("SELECT Star Cached Table") { - TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar") + TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar") TestSQLContext.cacheTable("selectStar") TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() TestSQLContext.uncacheTable("selectStar") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala new file mode 100644 index 0000000000000..cf7d79f42db1d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +class DataTypeSuite extends FunSuite { + + test("construct an ArrayType") { + val array = ArrayType(StringType) + + assert(ArrayType(StringType, false) === array) + } + + test("construct an MapType") { + val map = MapType(StringType, IntegerType) + + assert(MapType(StringType, IntegerType, true) === map) + } + + test("extract fields from a StructType") { + val struct = StructType( + StructField("a", IntegerType, true) :: + StructField("b", LongType, false) :: + StructField("c", StringType, true) :: + StructField("d", FloatType, true) :: Nil) + + assert(StructField("b", LongType, false) === struct("b")) + + intercept[IllegalArgumentException] { + struct("e") + } + + val expectedStruct = StructType( + StructField("b", LongType, false) :: + StructField("d", FloatType, true) :: Nil) + + assert(expectedStruct === struct(Set("b", "d"))) + intercept[IllegalArgumentException] { + struct(Set("b", "d", "e", "f")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala index 4f0b85f26254b..c87d762751e6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.File +import _root_.java.io.File /* Implicits */ import org.apache.spark.sql.test.TestSQLContext._ @@ -31,7 +31,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertTest") + testFile.registerTempTable("createAndInsertTest") // Add some data. 
testData.insertInto("createAndInsertTest") @@ -86,7 +86,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertSQLTest") + testFile.registerTempTable("createAndInsertSQLTest") sql("INSERT INTO createAndInsertSQLTest SELECT * FROM testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e17ecc87fd52a..6c7697ece8c56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner} +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -class JoinSuite extends QueryTest { +class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData @@ -36,6 +40,56 @@ class JoinSuite extends QueryTest { assert(planned.size === 1) } + test("join operator selection") { + def assertJoin(sqlString: String, c: Class[_]): Any = { + val rdd = sql(sqlString) + val physical = rdd.queryExecution.sparkPlan + val operators = physical.collect { + case j: ShuffledHashJoin => j + case j: HashOuterJoin => j + case j: LeftSemiJoinHash => j + case j: BroadcastHashJoin => j + case j: LeftSemiJoinBNL => j + case j: CartesianProduct => j + case j: BroadcastNestedLoopJoin => j + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + } + } + + val cases1 = Seq( + ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData right join 
testData2 ON key = a where key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]) + // TODO add BroadcastNestedLoopJoin + ) + cases1.foreach { c => assertJoin(c._1, c._2) } + } + test("multiple-key equi-join is hash-join") { val x = testData2.as('x) val y = testData2.as('y) @@ -116,6 +170,58 @@ class JoinSuite extends QueryTest { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + // Make sure we are choosing left.outputPartitioning as the + // outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """.stripMargin), + (null, 6) :: Nil) } test("right outer join") { @@ -127,11 +233,63 @@ class JoinSuite extends QueryTest { (4, "d", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + // Make sure we are choosing right.outputPartitioning as the + // outputPartitioning for the outer join operator. 
+ checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """.stripMargin), + (null, 6) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: Nil) } test("full outer join") { - val left = upperCaseData.where('N <= 4).as('left) - val right = upperCaseData.where('N >= 3).as('right) + upperCaseData.where('N <= 4).registerTempTable("left") + upperCaseData.where('N >= 3).registerTempTable("right") + + val left = UnresolvedRelation(None, "left", None) + val right = UnresolvedRelation(None, "right", None) checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), @@ -141,5 +299,74 @@ class JoinSuite extends QueryTest { (4, "D", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """.stripMargin), + (null, 10) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: + (null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: + (null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """.stripMargin), + (null, 10) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8e1e1971d968b..1fd8d27b34c59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -45,6 +45,7 @@ class QueryTest extends PlanTest { |${rdd.queryExecution} |== Exception == |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala new file mode 100644 index 0000000000000..651cb735ab7d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -0,0 +1,46 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + +class RowSuite extends FunSuite { + + test("create row") { + val expected = new GenericMutableRow(4) + expected.update(0, 2147483647) + expected.update(1, "this is a string") + expected.update(2, false) + expected.update(3, null) + val actual1 = Row(2147483647, "this is a string", false, null) + assert(expected.size === actual1.size) + assert(expected.getInt(0) === actual1.getInt(0)) + assert(expected.getString(1) === actual1.getString(1)) + assert(expected.getBoolean(2) === actual1.getBoolean(2)) + assert(expected(3) === actual1(3)) + + val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) + assert(expected.size === actual2.size) + assert(expected.getInt(0) === actual2.getInt(0)) + assert(expected.getString(1) === actual2.getString(1)) + assert(expected.getBoolean(2) === actual2.getBoolean(2)) + assert(expected(3) === actual2(3)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 08293f7f0ca30..584f71b3c13d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -29,21 +29,18 @@ class SQLConfSuite extends QueryTest { test("programmatic ways of basic setting and getting") { clear() - assert(getOption(testKey).isEmpty) - assert(getAll.toSet === Set()) + assert(getAllConfs.size === 0) - set(testKey, testVal) - assert(get(testKey) == testVal) - assert(get(testKey, testVal + "_") == testVal) - assert(getOption(testKey) == Some(testVal)) - assert(contains(testKey)) + setConf(testKey, testVal) + assert(getConf(testKey) == testVal) + assert(getConf(testKey, testVal + "_") == testVal) + assert(getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
- assert(TestSQLContext.get(testKey) == testVal) - assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getOption(testKey) == Some(testVal)) - assert(TestSQLContext.contains(testKey)) + assert(TestSQLContext.getConf(testKey) == testVal) + assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getAllConfs.contains(testKey)) clear() } @@ -51,23 +48,28 @@ class SQLConfSuite extends QueryTest { test("parse SQL set commands") { clear() sql(s"set $testKey=$testVal") - assert(get(testKey, testVal + "_") == testVal) - assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + assert(getConf(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - sql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - sql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + sql("set some.property=20") + assert(getConf("some.property", "0") == "20") + sql("set some.property = 40") + assert(getConf("some.property", "0") == "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(get(key, "0") == vs) + assert(getConf(key, "0") == vs) sql(s"set $key=") - assert(get(key, "0") == "") + assert(getConf(key, "0") == "") clear() } + test("deprecated property") { + clear() + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6736189c96d4b..9b2a36d33fca7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test._ /* Implicits */ @@ -424,26 +422,107 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(testKey, testVal), - Seq(testKey + testKey, testVal + testVal)) + Seq(s"$testKey=$testVal"), + Seq(s"${testKey + testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(nonexistentKey, "")) + Seq(Seq(s"$nonexistentKey=")) ) clear() } + + test("apply schema") { + val schema1 = StructType( + StructField("f1", IntegerType, false) :: + StructField("f2", StringType, false) :: + StructField("f3", BooleanType, false) :: + StructField("f4", IntegerType, true) :: Nil) + + val rowRDD1 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(values(0).toInt, values(1), values(2).toBoolean, v4) + } + + val schemaRDD1 = applySchema(rowRDD1, schema1) + schemaRDD1.registerTempTable("applySchema1") + checkAnswer( + sql("SELECT * FROM applySchema1"), + (1, "A1", true, null) :: + (2, "B2", false, null) :: + (3, "C3", true, null) :: + (4, "D4", true, 2147483644) :: Nil) + + checkAnswer( + sql("SELECT f1, f4 
FROM applySchema1"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + + val schema2 = StructType( + StructField("f1", StructType( + StructField("f11", IntegerType, false) :: + StructField("f12", BooleanType, false) :: Nil), false) :: + StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil) + + val rowRDD2 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) + } + + val schemaRDD2 = applySchema(rowRDD2, schema2) + schemaRDD2.registerTempTable("applySchema2") + checkAnswer( + sql("SELECT * FROM applySchema2"), + (Seq(1, true), Map("A1" -> null)) :: + (Seq(2, false), Map("B2" -> null)) :: + (Seq(3, true), Map("C3" -> null)) :: + (Seq(4, true), Map("D4" -> 2147483644)) :: Nil) + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + + // The value of a MapType column can be a mutable map. + val rowRDD3 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) + } + + val schemaRDD3 = applySchema(rowRDD3, schema2) + schemaRDD3.registerTempTable("applySchema3") + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index f2934da9a031d..5b84c658db942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -61,7 +61,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectData") + rdd.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) } @@ -69,7 +69,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectNullData") + rdd.registerTempTable("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null)) } @@ -77,7 +77,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectOptionalData") + rdd.registerTempTable("reflectOptionalData") assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null)) } @@ -85,7 +85,7 @@ class ScalaReflectionRelationSuite extends FunSuite { // Equality is broken for Arrays, so we test that separately. 
test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerAsTable("reflectBinary") + rdd.registerTempTable("reflectBinary") val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 330b20b315d63..c3ec82fb69778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) object TestData { val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") case class LargeAndSmallInts(a: Int, b: Int) val largeAndSmallInts: SchemaRDD = @@ -39,8 +41,8 @@ object TestData { LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil) - largeAndSmallInts.registerAsTable("largeAndSmallInts") - + largeAndSmallInts.registerTempTable("largeAndSmallInts") + case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = TestSQLContext.sparkContext.parallelize( @@ -50,7 +52,7 @@ object TestData { TestData2(2, 2) :: TestData2(3, 1) :: TestData2(3, 2) :: Nil) - testData2.registerAsTable("testData2") + testData2.registerTempTable("testData2") // TODO: There is no way to express null primitives as case classes currently... 
val testData3 = @@ -69,7 +71,7 @@ object TestData { UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: UpperCaseData(6, "F") :: Nil) - upperCaseData.registerAsTable("upperCaseData") + upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) val lowerCaseData = @@ -78,14 +80,14 @@ object TestData { LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: LowerCaseData(4, "d") :: Nil) - lowerCaseData.registerAsTable("lowerCaseData") + lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerAsTable("arrayData") + arrayData.registerTempTable("arrayData") case class MapData(data: Map[Int, String]) val mapData = @@ -95,18 +97,18 @@ object TestData { MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - mapData.registerAsTable("mapData") + mapData.registerTempTable("mapData") case class StringData(s: String) val repeatedData = TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerAsTable("repeatedData") + repeatedData.registerTempTable("repeatedData") val nullableRepeatedData = TestSQLContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - nullableRepeatedData.registerAsTable("nullableRepeatedData") + nullableRepeatedData.registerTempTable("nullableRepeatedData") case class NullInts(a: Integer) val nullInts = @@ -116,7 +118,15 @@ object TestData { NullInts(3) :: NullInts(null) :: Nil ) - nullInts.registerAsTable("nullInts") + nullInts.registerTempTable("nullInts") + + val allNulls = + TestSQLContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil) + allNulls.registerTempTable("allNulls") case class NullStrings(n: Int, s: String) val nullStrings = @@ -124,8 +134,21 @@ object TestData { NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil) - nullStrings.registerAsTable("nullStrings") + nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") + TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") + + val unparsedStrings = + TestSQLContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + + case class TimestampField(time: Timestamp) + val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => + TimestampField(new Timestamp(i)) + }) + timestamps.registerTempTable("timestamps") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala new file mode 100644 index 0000000000000..76aa9b0081d7e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class UDFSuite extends QueryTest { + + test("Simple UDF") { + registerFunction("strLenScala", (_: String).length) + assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + } + + test("TwoArgument UDF") { + registerFunction("strLenScala", (_: String).length + (_:Int)) + assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 020baf0c7ec6f..203ff847e94cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -59,7 +59,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(person :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[PersonBean]) - schemaRDD.registerAsTable("people") + schemaRDD.registerTempTable("people") javaSqlCtx.sql("SELECT * FROM people").collect() } @@ -76,7 +76,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -101,7 +101,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -127,7 +127,7 @@ class JavaSQLSuite extends FunSuite { var schemaRDD = javaSqlCtx.jsonRDD(rdd) - schemaRDD.registerAsTable("jsonTable1") + schemaRDD.registerTempTable("jsonTable1") assert( javaSqlCtx.sql("select * from jsonTable1").collect.head.row === @@ -144,7 +144,7 @@ class JavaSQLSuite extends FunSuite { rdd.saveAsTextFile(path) schemaRDD = javaSqlCtx.jsonFile(path) - schemaRDD.registerAsTable("jsonTable2") + schemaRDD.registerTempTable("jsonTable2") assert( javaSqlCtx.sql("select * from jsonTable2").collect.head.row === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala new file mode 100644 index 0000000000000..ff1debff0f8c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.types.util.DataTypeConversions +import org.scalatest.FunSuite + +import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField} +import org.apache.spark.sql.{StructType => SStructType} +import DataTypeConversions._ + +class ScalaSideDataTypeConversionSuite extends FunSuite { + + def checkDataType(scalaDataType: SDataType) { + val javaDataType = asJavaDataType(scalaDataType) + val actual = asScalaDataType(javaDataType) + assert(scalaDataType === actual, s"Converted data type ${actual} " + + s"does not equal the expected data type ${scalaDataType}") + } + + test("convert data types") { + // Simple DataTypes. + checkDataType(org.apache.spark.sql.StringType) + checkDataType(org.apache.spark.sql.BinaryType) + checkDataType(org.apache.spark.sql.BooleanType) + checkDataType(org.apache.spark.sql.TimestampType) + checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DoubleType) + checkDataType(org.apache.spark.sql.FloatType) + checkDataType(org.apache.spark.sql.ByteType) + checkDataType(org.apache.spark.sql.IntegerType) + checkDataType(org.apache.spark.sql.LongType) + checkDataType(org.apache.spark.sql.ShortType) + + // Simple ArrayType. + val simpleScalaArrayType = + org.apache.spark.sql.ArrayType(org.apache.spark.sql.StringType, true) + checkDataType(simpleScalaArrayType) + + // Simple MapType. + val simpleScalaMapType = + org.apache.spark.sql.MapType(org.apache.spark.sql.StringType, org.apache.spark.sql.LongType) + checkDataType(simpleScalaMapType) + + // Simple StructType. + val simpleScalaStructType = SStructType( + SStructField("a", org.apache.spark.sql.DecimalType, false) :: + SStructField("b", org.apache.spark.sql.BooleanType, true) :: + SStructField("c", org.apache.spark.sql.LongType, true) :: + SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) + checkDataType(simpleScalaStructType) + + // Complex StructType. + val complexScalaStructType = SStructType( + SStructField("simpleArray", simpleScalaArrayType, true) :: + SStructField("simpleMap", simpleScalaMapType, true) :: + SStructField("simpleStruct", simpleScalaStructType, true) :: + SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil) + checkDataType(complexScalaStructType) + + // Complex ArrayType. + val complexScalaArrayType = + org.apache.spark.sql.ArrayType(complexScalaStructType, true) + checkDataType(complexScalaArrayType) + + // Complex MapType. 
+ val complexScalaMapType = + org.apache.spark.sql.MapType(complexScalaStructType, complexScalaArrayType, false) + checkDataType(complexScalaMapType) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 829342215e691..75f653f3280bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -22,7 +22,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -166,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.rewind() seq.foreach { expected => - logger.info("buffer = " + buffer + ", expected = " + expected) + logInfo("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) assert( expected === extracted, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 86727b93f3659..b561b44ad7ee2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -73,4 +73,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq) } + + test("SPARK-2729 regression: timestamp data type") { + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + + TestSQLContext.cacheTable("timestamps") + + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 35ab14cbc353d..3baa6f8ec0c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -41,7 +41,7 @@ object TestNullableColumnAccessor { class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index d8898527baa39..dc813fe146c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -37,7 +37,7 @@ object TestNullableColumnBuilder { class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnBuilder(_) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 6d688ea95cfc0..72c19fa31d980 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -42,4 +42,3 @@ object TestCompressibleColumnBuilder { builder } } - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 215618e852eb2..76b1724471442 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite { test("count is partially aggregated") { val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed - val planned = PartialAggregation(query).head - val aggregations = planned.collect { case a: Aggregate => a } + val planned = HashAggregation(query).head + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } assert(aggregations.size === 2) } test("count distinct is not partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } test("mixed aggregates are not partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index e55648b8ed15a..2cab5e0c44d92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._ * Note: this is only a rough example of how TGFs can be expressed, the final version will likely * involve a lot more sugar for cleaner use in Scala/Java/etc. 
*/ -case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator { +case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { def children = input protected def makeOutput() = 'nameAndAge.string :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index e765cfc83a397..75c0589eb208e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ -protected case class Schema(output: Seq[Attribute]) extends LeafNode - class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -127,6 +123,18 @@ class JsonSuite extends QueryTest { checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType)) // StructType checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) @@ -164,18 +172,18 @@ class JsonSuite extends QueryTest { test("Primitive field and type inferring") { val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -192,29 +200,30 @@ class JsonSuite extends QueryTest { test("Complex field and type inferring") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - val expectedSchema = - 
AttributeReference("arrayOfArray1", ArrayType(ArrayType(StringType)), true)() :: - AttributeReference("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true)() :: - AttributeReference("arrayOfBigInteger", ArrayType(DecimalType), true)() :: - AttributeReference("arrayOfBoolean", ArrayType(BooleanType), true)() :: - AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() :: - AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() :: - AttributeReference("arrayOfLong", ArrayType(LongType), true)() :: - AttributeReference("arrayOfNull", ArrayType(StringType), true)() :: - AttributeReference("arrayOfString", ArrayType(StringType), true)() :: - AttributeReference("arrayOfStruct", ArrayType( - StructType(StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: Nil)), true)() :: - AttributeReference("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType, true) :: Nil), true)() :: - AttributeReference("structWithArrayFields", StructType( + val expectedSchema = StructType( + StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType), true) :: + StructField("arrayOfInteger", ArrayType(IntegerType), true) :: + StructField("arrayOfLong", ArrayType(LongType), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType), true) :: + StructField("arrayOfStruct", ArrayType( + StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: Nil)), true) :: + StructField("struct", StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType, true) :: Nil), true) :: + StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType), true) :: - StructField("field2", ArrayType(StringType), true) :: Nil), true)() :: Nil + StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -282,7 +291,7 @@ class JsonSuite extends QueryTest { ignore("Complex field and type inferring (Ignored)") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, "field1" and "field2" are treated as aliases. We should fix it. 
checkAnswer( @@ -301,17 +310,17 @@ class JsonSuite extends QueryTest { test("Type conflict in primitive field values") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - val expectedSchema = - AttributeReference("num_bool", StringType, true)() :: - AttributeReference("num_num_1", LongType, true)() :: - AttributeReference("num_num_2", DecimalType, true)() :: - AttributeReference("num_num_3", DoubleType, true)() :: - AttributeReference("num_str", StringType, true)() :: - AttributeReference("str_bool", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("num_bool", StringType, true) :: + StructField("num_num_1", LongType, true) :: + StructField("num_num_2", DecimalType, true) :: + StructField("num_num_3", DoubleType, true) :: + StructField("num_str", StringType, true) :: + StructField("str_bool", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -365,7 +374,7 @@ class JsonSuite extends QueryTest { ignore("Type conflict in primitive field values (Ignored)") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expreesion. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -426,17 +435,17 @@ class JsonSuite extends QueryTest { test("Type conflict in complex field values") { val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) - val expectedSchema = - AttributeReference("array", ArrayType(IntegerType), true)() :: - AttributeReference("num_struct", StringType, true)() :: - AttributeReference("str_array", StringType, true)() :: - AttributeReference("struct", StructType( - StructField("field", StringType, true) :: Nil), true)() :: - AttributeReference("struct_array", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("array", ArrayType(IntegerType), true) :: + StructField("num_struct", StringType, true) :: + StructField("str_array", StringType, true) :: + StructField("struct", StructType( + StructField("field", StringType, true) :: Nil), true) :: + StructField("struct_array", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -450,14 +459,14 @@ class JsonSuite extends QueryTest { test("Type conflict in array elements") { val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) - val expectedSchema = - AttributeReference("array1", ArrayType(StringType), true)() :: - AttributeReference("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil)), true)() :: Nil + val expectedSchema = StructType( + StructField("array1", ArrayType(StringType, true), true) :: + StructField("array2", ArrayType(StructType( + StructField("field", LongType, true) :: Nil)), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -475,17 +484,17 
@@ class JsonSuite extends QueryTest { test("Handling missing fields") { val jsonSchemaRDD = jsonRDD(missingFields) - val expectedSchema = - AttributeReference("a", BooleanType, true)() :: - AttributeReference("b", LongType, true)() :: - AttributeReference("c", ArrayType(IntegerType), true)() :: - AttributeReference("d", StructType( - StructField("field", BooleanType, true) :: Nil), true)() :: - AttributeReference("e", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("a", BooleanType, true) :: + StructField("b", LongType, true) :: + StructField("c", ArrayType(IntegerType), true) :: + StructField("d", StructType( + StructField("field", BooleanType, true) :: Nil), true) :: + StructField("e", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") } test("Loading a JSON dataset from a text file") { @@ -494,18 +503,18 @@ class JsonSuite extends QueryTest { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val jsonSchemaRDD = jsonFile(path) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -518,4 +527,53 @@ class JsonSuite extends QueryTest { "this is a simple string.") :: Nil ) } + + test("Applying schemas") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val schema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + val jsonSchemaRDD1 = jsonFile(path, schema) + + assert(schema === jsonSchemaRDD1.schema) + + jsonSchemaRDD1.registerTempTable("jsonTable1") + + checkAnswer( + sql("select * from jsonTable1"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + + val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) + + assert(schema === jsonSchemaRDD2.schema) + + jsonSchemaRDD2.registerTempTable("jsonTable2") + + checkAnswer( + sql("select * from jsonTable2"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple 
string.") :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3c911e9a4e7b1..9933575038bd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job + import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} @@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -99,9 +101,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA ParquetTestData.writeNestedFile3() ParquetTestData.writeNestedFile4() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) - .registerAsTable("testfiltersource") + .registerTempTable("testfiltersource") } override def afterAll() { @@ -207,18 +209,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection of simple Parquet file") { - val scanner = new ParquetTableScan( - ParquetTestData.testData.output, - ParquetTestData.testData, - Seq())(TestSQLContext) - val projected = scanner.pruneColumns(ParquetTypesConverter - .convertToAttributes(MessageTypeParser - .parseMessageType(ParquetTestData.subTestSchema))) - assert(projected.output.size === 2) - val result = projected - .execute() - .map(_.copy()) - .collect() + val result = ParquetTestData.testData.select('myboolean, 'mylong).collect() result.zipWithIndex.foreach { case (row, index) => { if (index % 3 == 0) @@ -256,7 +247,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Creating case class RDD table") { TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - .registerAsTable("tmp") + .registerTempTable("tmp") val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) var counter = 1 rdd.foreach { @@ -275,7 +266,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .map(i => TestRDDEntry(i, s"val_$i")) rdd.saveAsParquetFile(path) val readFile = parquetFile(path) - readFile.registerAsTable("tmpx") + readFile.registerTempTable("tmpx") val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { @@ -289,9 +280,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val dirname = Utils.createTempDir() val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - source_rdd.registerAsTable("source") + source_rdd.registerTempTable("source") val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) - dest_rdd.registerAsTable("dest") + dest_rdd.registerTempTable("dest") sql("INSERT OVERWRITE INTO dest SELECT * FROM 
source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) @@ -556,7 +547,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir1.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() assert(tmp.size === 2) @@ -571,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir2.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -598,7 +589,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir3.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -617,7 +608,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = TestSQLContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = sql("SELECT data1 FROM mapTable").collect() assert(result1.size === 1) assert(result1(0)(0) @@ -634,7 +625,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() assert(result1.size === 1) val entry1 = result1(0)(0) @@ -667,7 +658,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpcopy") + .registerTempTable("tmpcopy") val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() assert(tmpdata.size === 2) assert(tmpdata(0).size === 2) @@ -688,7 +679,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpmapcopy") + .registerTempTable("tmpmapcopy") val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() assert(result1.size === 1) assert(result1(0)(0) === 2) diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml new file mode 100644 index 0000000000000..c6f60c18804a4 --- /dev/null +++ b/sql/hive-thriftserver/pom.xml @@ -0,0 +1,82 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-hive-thriftserver_2.10 + jar + Spark Project Hive Thrift Server + http://spark.apache.org/ + + hive-thriftserver + + + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + + + org.spark-project.hive + hive-cli + ${hive.version} + + + org.spark-project.hive + hive-jdbc + ${hive.version} + + + org.spark-project.hive + 
hive-beeline + ${hive.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala new file mode 100644 index 0000000000000..08d3f983d9e71 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService +import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} + +import org.apache.spark.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a + * `HiveThriftServer2` thrift server. + */ +private[hive] object HiveThriftServer2 extends Logging { + var LOG = LogFactory.getLog(classOf[HiveServer2]) + + def main(args: Array[String]) { + val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + + if (!optionsProcessor.process(args)) { + logWarning("Error starting HiveThriftServer2 with given arguments") + System.exit(-1) + } + + val ss = new SessionState(new HiveConf(classOf[SessionState])) + + // Set all properties specified via command line. 
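+ // The command-line properties end up in the HiveConf held by the SessionState; the loop
+ // below simply logs the resolved values at debug level so operators can verify what the
+ // server actually starts with.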
+ val hiveConf: HiveConf = ss.getConf + hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => + logDebug(s"HiveConf var: $k=$v") + } + + SessionState.start(ss) + + logInfo("Starting SparkContext") + SparkSQLEnv.init() + SessionState.start(ss) + + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.sparkContext.stop() + } + } + ) + + try { + val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) + server.init(hiveConf) + server.start() + logInfo("HiveThriftServer2 started") + } catch { + case e: Exception => + logError("Error starting HiveThriftServer2", e) + System.exit(-1) + } + } +} + +private[hive] class HiveThriftServer2(hiveContext: HiveContext) + extends HiveServer2 + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + setSuperField(this, "cliService", sparkSqlCliService) + addService(sparkSqlCliService) + + val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala new file mode 100644 index 0000000000000..599294dfbb7d7 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +private[hive] object ReflectionUtils { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + setAncestorField(obj, 1, fieldName, fieldValue) + } + + def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) { + val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.set(obj, fieldValue) + } + + def getSuperField[T](obj: AnyRef, fieldName: String): T = { + getAncestorField[T](obj, 1, fieldName) + } + + def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = { + val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(clazz).asInstanceOf[T] + } + + def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = { + invoke(clazz, null, methodName, args: _*) + } + + def invoke( + clazz: Class[_], + obj: AnyRef, + methodName: String, + args: (Class[_], AnyRef)*): AnyRef = { + + val (types, values) = args.unzip + val method = clazz.getDeclaredMethod(methodName, types: _*) + method.setAccessible(true) + method.invoke(obj, values.toSeq: _*) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala new file mode 100755 index 0000000000000..4d0c506c5a397 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io._ +import java.util.{ArrayList => JArrayList} + +import jline.{ConsoleReader, History} +import org.apache.commons.lang.StringUtils +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.thrift.transport.TSocket + +import org.apache.spark.Logging + +private[hive] object SparkSQLCLIDriver { + private var prompt = "spark-sql" + private var continuedPrompt = "".padTo(prompt.length, ' ') + private var transport:TSocket = _ + + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. + */ + def installSignalHandler() { + HiveInterruptUtils.add(new HiveInterruptCallback { + override def interrupt() { + // Handle remote execution mode + if (SparkSQLEnv.sparkContext != null) { + SparkSQLEnv.sparkContext.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket.close() + } + } + } + }) + } + + def main(args: Array[String]) { + val oproc = new OptionsProcessor() + if (!oproc.process_stage1(args)) { + System.exit(1) + } + + // NOTE: It is critical to do this here so that log4j is reinitialized + // before any of the other core hive classes are loaded + var logInitFailed = false + var logInitDetailMessage: String = null + try { + logInitDetailMessage = LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => + logInitFailed = true + logInitDetailMessage = e.getMessage + } + + val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) + + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (!oproc.process_stage2(sessionState)) { + System.exit(2) + } + + if (!sessionState.getIsSilent) { + if (logInitFailed) System.err.println(logInitDetailMessage) + else SessionState.getConsole.printInfo(logInitDetailMessage) + } + + // Set all properties specified via command line. 
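+ // Each key/value pair parsed from the command line is copied into the active HiveConf and
+ // also recorded in the session's overridden-configurations map.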
+ val conf: HiveConf = sessionState.getConf + sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => + conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + sessionState.getOverriddenConfigurations.put( + item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + } + + SessionState.start(sessionState) + + // Clean up after we exit + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.stop() + } + } + ) + + // "-h" option has been passed, so connect to Hive thrift server. + if (sessionState.getHost != null) { + sessionState.connect() + if (sessionState.isRemoteMode) { + prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt + continuedPrompt = "".padTo(prompt.length, ' ') + } + } + + if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + // Hadoop-20 and above - we need to augment classpath using hiveconf + // components. + // See also: code in ExecDriver.java + var loader = conf.getClassLoader + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ",")) + } + conf.setClassLoader(loader) + Thread.currentThread().setContextClassLoader(loader) + } + + val cli = new SparkSQLCLIDriver + cli.setHiveVariables(oproc.getHiveVariables) + + // TODO work around for set the log output to console, because the HiveContext + // will set the output into an invalid buffer. + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + // Execute -i init files (always in silent mode) + cli.processInitFiles(sessionState) + + if (sessionState.execString != null) { + System.exit(cli.processLine(sessionState.execString)) + } + + try { + if (sessionState.fileName != null) { + System.exit(cli.processFile(sessionState.fileName)) + } + } catch { + case e: FileNotFoundException => + System.err.println(s"Could not open input file for reading. (${e.getMessage})") + System.exit(3) + } + + val reader = new ConsoleReader() + reader.setBellEnabled(false) + // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) + CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + + val historyDirectory = System.getProperty("user.home") + + try { + if (new File(historyDirectory).exists()) { + val historyFile = historyDirectory + File.separator + ".hivehistory" + reader.setHistory(new History(new File(historyFile))) + } else { + System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + " does not exist. History will not be available during this session.") + } + } catch { + case e: Exception => + System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + "history file. 
History will not be available during this session.") + System.err.println(e.getMessage) + } + + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] + + var ret = 0 + var prefix = "" + val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", + classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) + + def promptWithCurrentDB = s"$prompt$currentDB" + def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + + var currentPrompt = promptWithCurrentDB + var line = reader.readLine(currentPrompt + "> ") + + while (line != null) { + if (prefix.nonEmpty) { + prefix += '\n' + } + + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } + + line = reader.readLine(currentPrompt + "> ") + } + + sessionState.close() + + System.exit(ret) + } +} + +private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { + private val sessionState = SessionState.get().asInstanceOf[CliSessionState] + + private val LOG = LogFactory.getLog("CliDriver") + + private val console = new SessionState.LogHelper(LOG) + + private val conf: Configuration = + if (sessionState != null) sessionState.getConf else new Configuration() + + // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver + // because the Hive unit tests do not go through the main() code path. + if (!sessionState.isRemoteMode) { + SparkSQLEnv.init() + } + + override def processCmd(cmd: String): Int = { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + if (cmd_trimmed.toLowerCase.equals("quit") || + cmd_trimmed.toLowerCase.equals("exit") || + tokens(0).equalsIgnoreCase("source") || + cmd_trimmed.startsWith("!") || + tokens(0).toLowerCase.equals("list") || + sessionState.isRemoteMode) { + val start = System.currentTimeMillis() + super.processCmd(cmd) + val end = System.currentTimeMillis() + val timeTaken: Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds") + 0 + } else { + var ret = 0 + val hconf = conf.asInstanceOf[HiveConf] + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + + if (proc != null) { + if (proc.isInstanceOf[Driver]) { + val driver = new SparkSQLDriver + + driver.init() + val out = sessionState.out + val start:Long = System.currentTimeMillis() + if (sessionState.getIsVerbose) { + out.println(cmd) + } + + val rc = driver.run(cmd) + ret = rc.getResponseCode + if (ret != 0) { + console.printError(rc.getErrorMessage()) + driver.close() + return ret + } + + val res = new JArrayList[String]() + + if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { + // Print the column names. 
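+ // The header is only printed when hive.cli.print.header is enabled: the field names from
+ // the driver's result schema are joined with tabs into a single line.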
+ Option(driver.getSchema.getFieldSchemas).map { fields => + out.println(fields.map(_.getName).mkString("\t")) + } + } + + try { + while (!out.checkError() && driver.getResults(res)) { + res.foreach(out.println) + res.clear() + } + } catch { + case e:IOException => + console.printError( + s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + ret = 1 + } + + val cret = driver.close() + if (ret == 0) { + ret = cret + } + + val end = System.currentTimeMillis() + if (end > start) { + val timeTaken:Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds", null) + } + + // Destroy the driver to release all the locks. + driver.destroy() + } else { + if (sessionState.getIsVerbose) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } + ret = proc.run(cmd_1).getResponseCode + } + } + ret + } + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala new file mode 100644 index 0000000000000..42cbf363b274f --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io.IOException +import java.util.{List => JList} +import javax.security.auth.login.LoginException + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hive.service.Service.STATE +import org.apache.hive.service.auth.HiveAuthFactory +import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.{AbstractService, Service, ServiceException} + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +private[hive] class SparkSQLCLIService(hiveContext: HiveContext) + extends CLIService + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + setSuperField(this, "sessionManager", sparkSqlSessionManager) + addService(sparkSqlSessionManager) + + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + val serverUserName = ShimLoader.getHadoopShims + .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) + setSuperField(this, "serverUserName", serverUserName) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } + + initCompositeService(hiveConf) + } +} + +private[thriftserver] trait ReflectedCompositeService { this: AbstractService => + def initCompositeService(hiveConf: HiveConf) { + // Emulating `CompositeService.init(hiveConf)` + val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") + serviceList.foreach(_.init(hiveConf)) + + // Emulating `AbstractService.init(hiveConf)` + invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) + setAncestorField(this, 3, "hiveConf", hiveConf) + invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED) + getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000000..7463df1f47d43 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.util.{ArrayList => JArrayList} + +import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver with Logging { + + private var tableSchema: Schema = _ + private var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logDebug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.size == 0) { + new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + // TODO unify the error code + try { + val execution = context.executePlan(context.sql(command).logicalPlan) + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logError(s"Failed in [$command]", cause) + new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getSchema: Schema = tableSchema + + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala new file mode 100644 index 0000000000000..582264eb59f83 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.session.SessionState + +import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} +import org.apache.spark.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{SparkConf, SparkContext} + +/** A singleton object for the master program. The slaves should not access this. */ +private[hive] object SparkSQLEnv extends Logging { + logDebug("Initializing SparkSQLEnv") + + var hiveContext: HiveContext = _ + var sparkContext: SparkContext = _ + + def init() { + if (hiveContext == null) { + sparkContext = new SparkContext(new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + + sparkContext.addSparkListener(new StatsReportListener()) + + hiveContext = new HiveContext(sparkContext) { + @transient override lazy val sessionState = SessionState.get() + @transient override lazy val hiveconf = sessionState.getConf + } + } + } + + /** Cleans up and shuts down the Spark SQL environments. */ + def stop() { + logDebug("Shutting down Spark SQL Environment") + // Stop the SparkContext + if (SparkSQLEnv.sparkContext != null) { + sparkContext.stop() + sparkContext = null + hiveContext = null + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..6b3275b4eaf04 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.session.SessionManager + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala new file mode 100644 index 0000000000000..dee092159dd4c --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.server + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.math.{random, round} + +import java.sql.Timestamp +import java.util.{Map => JMap} + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} + +/** + * Executes queries using Spark SQL, and maintains a list of handles to active queries. 
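+ *
+ * Each `ExecuteStatementOperation` produced here compiles its SQL text with `HiveContext.sql`,
+ * runs the plan as a Spark job, and streams the results back to Thrift clients in batches via
+ * `getNextRowSet`, converting Catalyst values into Hive `ColumnValue`s along the way.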
+ */ +class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { + val handleToOperation = ReflectionUtils + .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + + override def newExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + async: Boolean): ExecuteStatementOperation = synchronized { + + val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logDebug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No. + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + dataTypes(curCol) match { + case StringType => + row.addString(sparkRow(curCol).asInstanceOf[String]) + case IntegerType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) + case BooleanType => + row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) + case DoubleType => + row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) + case FloatType => + row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) + case DecimalType => + val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal + row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) + case ByteType => + row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) + case ShortType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) + case TimestampType => + row.addColumnValue( + ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) + row.addColumnValue(ColumnValue.stringValue(hiveString)) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def getResultSetSchema: TableSchema = { + logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logInfo(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + result = hiveContext.sql(statement) + logDebug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = result.queryExecution.toRdd.toLocalIterator + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some failures don't 
inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logError("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } + } + + handleToOperation.put(operation.getHandle, operation) + operation + } +} diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt new file mode 100644 index 0000000000000..850f8014b6f05 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt @@ -0,0 +1,5 @@ +238val_238 +86val_86 +311val_311 +27val_27 +165val_165 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala new file mode 100644 index 0000000000000..69f19f826a802 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, InputStreamReader, PrintWriter} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { + val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli") + val METASTORE_PATH = TestUtils.getMetastorePath("cli") + + override def beforeAll() { + val pb = new ProcessBuilder( + "../../bin/spark-sql", + "--master", + "local", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) + + process = pb.start() + outputWriter = new PrintWriter(process.getOutputStream, true) + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "spark-sql>") + } + + override def afterAll() { + process.destroy() + process.waitFor() + } + + test("simple commands") { + val dataFilePath = getDataFile("data/files/small_kv.txt") + executeQuery("create table hive_test1(key int, val string);") + executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;") + executeQuery("cache table hive_test1", "Time taken") + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala new file mode 100644 index 0000000000000..b7b7c9957ac34 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ + +import java.io.{BufferedReader, InputStreamReader} +import java.net.ServerSocket +import java.sql.{Connection, DriverManager, Statement} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util.getTempFilePath + +/** + * Test for the HiveThriftServer2 using JDBC. + */ +class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging { + + val WAREHOUSE_PATH = getTempFilePath("warehouse") + val METASTORE_PATH = getTempFilePath("metastore") + + val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver" + val TABLE = "test" + val HOST = "localhost" + val PORT = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. 
+ val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } + + // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. + val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean + + Class.forName(DRIVER_NAME) + + override def beforeAll() { launchServer() } + + override def afterAll() { stopServer() } + + private def launchServer(args: Seq[String] = Seq.empty) { + // Forking a new process to start the Hive Thrift server. The reason to do this is it is + // hard to clean up Hive resources entirely, so we just start a new process and kill + // that process for cleanup. + val defaultArgs = Seq( + "../../sbin/start-thriftserver.sh", + "--master local", + "--hiveconf", + "hive.root.logger=INFO,console", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") + val pb = new ProcessBuilder(defaultArgs ++ args) + val environment = pb.environment() + environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) + environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) + process = pb.start() + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "ThriftBinaryCLIService listening on") + + // Spawn a thread to read the output from the forked process. + // Note that this is necessary since in some configurations, log4j could be blocked + // if its output to stderr are not read, and eventually blocking the entire test suite. + future { + while (true) { + val stdout = readFrom(inputReader) + val stderr = readFrom(errorReader) + if (VERBOSE && stdout.length > 0) { + println(stdout) + } + if (VERBOSE && stderr.length > 0) { + println(stderr) + } + Thread.sleep(50) + } + } + } + + private def stopServer() { + process.destroy() + process.waitFor() + } + + test("test query execution against a Hive Thrift server") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test") + stmt.execute("DROP TABLE IF EXISTS test_cached") + stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") + stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CACHE TABLE test_cached") + + var rs = stmt.executeQuery("select count(*) from test") + rs.next() + assert(rs.getInt(1) === 5) + + rs = stmt.executeQuery("select count(*) from test_cached") + rs.next() + assert(rs.getInt(1) === 4) + + stmt.close() + } + + def getConnection: Connection = { + val connectURI = s"jdbc:hive2://localhost:$PORT/" + DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") + } + + def createStatement(): Statement = getConnection.createStatement() +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala new file mode 100644 index 0000000000000..bb2242618fbef --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, PrintWriter} +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException + +object TestUtils { + val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss") + + def getWarehousePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" + + timestamp.format(new Date) + } + + def getMetastorePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" + + timestamp.format(new Date) + } + + // Dummy function for initialize the log4j properties. + def init() { } + + // initialize log4j + try { + LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => // Ignore the error. + } +} + +trait TestUtils { + var process : Process = null + var outputWriter : PrintWriter = null + var inputReader : BufferedReader = null + var errorReader : BufferedReader = null + + def executeQuery( + cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = { + println("Executing: " + cmd + ", expecting output: " + outputMessage) + outputWriter.write(cmd + "\n") + outputWriter.flush() + waitForQuery(timeout, outputMessage) + } + + protected def waitForQuery(timeout: Long, message: String): String = { + if (waitForOutput(errorReader, message, timeout)) { + Thread.sleep(500) + readOutput() + } else { + assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput()) + null + } + } + + // Wait for the specified str to appear in the output. + protected def waitForOutput( + reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = { + val startTime = System.currentTimeMillis + var out = "" + while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) { + out += readFrom(reader) + } + out.contains(str) + } + + // Read stdout output and filter out garbage collection messages. 
+ protected def readOutput(): String = { + val output = readFrom(inputReader) + // Remove GC Messages + val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC")) + .mkString("\n") + filteredOutput + } + + protected def readFrom(reader: BufferedReader): String = { + var out = "" + var c = 0 + while (reader.ready) { + c = reader.read() + out += c.asInstanceOf[Char] + } + out + } + + protected def getDataFile(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(name) + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c69e93ba2b9ba..4fef071161719 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -52,7 +52,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def blackList = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", - "hook_context", + "hook_context_cs", "mapjoin_hook", "multi_sahooks", "overridden_confs", @@ -289,7 +289,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "compute_stats_empty_table", "compute_stats_long", "compute_stats_string", - "compute_stats_table", "convert_enum_to_string", "correlationoptimizer1", "correlationoptimizer10", @@ -395,7 +394,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_sort_9", "groupby_sort_test_1", "having", - "having1", "implicit_cast1", "innerjoin", "inoutdriver", @@ -697,8 +695,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf7", "udf8", "udf9", - "udf_E", - "udf_PI", "udf_abs", "udf_acos", "udf_add", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 1699ffe06ce15..93d00f7c37c9b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -32,7 +32,7 @@ Spark Project Hive http://spark.apache.org/ - hive + hive diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 201c85f3d501e..53f3dc11dbb9f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -23,29 +23,36 @@ import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.io.TimestampWritable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, 
OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand /** - * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is - * created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. + * DEPRECATED: Use HiveContext instead. */ +@deprecated(""" + Use HiveContext instead. It will still create a local metastore if one is not specified. + However, note that the default directory is ./metastore_db, not ./metastore + """, "1.1") class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val metastorePath = new File("metastore").getCanonicalPath @@ -53,9 +60,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Sets up the system initially or after a RESET command */ protected def configure() { - set("javax.jdo.option.ConnectionURL", + setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") - set("hive.metastore.warehouse.dir", warehousePath) + setConf("hive.metastore.warehouse.dir", warehousePath) } configure() // Must be called before initializing the catalog below. @@ -68,15 +75,29 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => + // Change the default SQL dialect to HiveQL + override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } - /** - * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. - */ + override def sql(sqlText: String): SchemaRDD = { + // TODO: Create a framework for registering parsers instead of just hardcoding if statements. + if (dialect == "sql") { + super.sql(sqlText) + } else if (dialect == "hiveql") { + new SchemaRDD(this, HiveQl.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") + } + } + + @deprecated("hiveql() is deprecated as the sql function now parses using HiveQL by default. " + + s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - /** An alias for `hiveql`. */ + @deprecated("hql() is deprecated as the sql function now parses using HiveQL by default. " + + s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) /** @@ -90,9 +111,84 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } + /** + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. 
+ */ + def analyze(tableName: String) { + val relation = catalog.lookupRelation(None, tableName) match { + case LowerCaseSchema(r) => r + case o => o + } + + relation match { + case relation: MetastoreRelation => { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDir) { + fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + } else { + fileStatus.getLen + } + + size + } + + def getFileSizeForTable(conf: HiveConf, table: Table): Long = { + val path = table.getPath() + var size: Long = 0L + try { + val fs = path.getFileSystem(conf) + size = calculateTableSize(fs, path) + } catch { + case e: Exception => + logWarning( + s"Failed to get the size of table ${table.getTableName} in the " + + s"database ${table.getDbName} because of ${e.toString}", e) + size = 0L + } + + size + } + + val tableParameters = relation.hiveQlTable.getParameters + val oldTotalSize = + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)).map(_.toLong).getOrElse(0L) + val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) + // Update the Hive metastore if the total size of the table is different than the size + // recorded in the Hive metastore. + // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). + if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + tableParameters.put(StatsSetupConst.TOTAL_SIZE, newTotalSize.toString) + val hiveTTable = relation.hiveQlTable.getTTable + hiveTTable.setParameters(tableParameters) + val tableFullName = + relation.hiveQlTable.getDbName() + "." + relation.hiveQlTable.getTableName() + + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } + } + case otherRelation => + throw new NotImplementedError( + s"Analyze has only implemented for Hive tables, " + + s"but ${tableName} is a ${otherRelation.nodeName}") + } + } + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient - protected val outputBuffer = new java.io.OutputStream { + protected lazy val outputBuffer = new java.io.OutputStream { var pos: Int = 0 var buffer = new Array[Int](10240) def write(i: Int): Unit = { @@ -122,21 +218,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * SQLConf and HiveConf contracts: when the hive session is first initialized, params in * HiveConf will get picked up by the SQLConf. Additionally, any properties set by - * set() or a SET command inside hql() or sql() will be set in the SQLConf *as well as* + * set() or a SET command inside sql() will be set in the SQLConf *as well as* * in the HiveConf. */ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) - set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. 
ss } sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") - override def set(key: String, value: String): Unit = { - super.set(key, value) + override def setConf(key: String, value: String): Unit = { + super.setConf(key, value) runSqlHive(s"SET $key=$value") } @@ -152,10 +248,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + // Note that HiveUDFs will be overridden by functions registered in this context. + override protected[sql] lazy val functionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry + /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = - new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) /** * Runs the specified SQL query using Hive. @@ -204,7 +304,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } catch { case e: Exception => - logger.error( + logError( s""" |====================== |HIVE FAILURE OUTPUT @@ -231,7 +331,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveTableScans, DataSinks, Scripts, - PartialAggregation, + HashAggregation, LeftSemiJoin, HashJoin, BasicOperators, @@ -247,7 +347,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. override lazy val optimizedPlan = - optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) + optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))) override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) @@ -255,14 +355,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DecimalType, TimestampType, BinaryType) - protected def toHiveString(a: (Any, DataType)): String = a match { + protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -279,9 +379,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index ad7dc0ecdb1bf..354fcd53f303b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -152,8 +152,9 @@ private[hive] trait HiveInspectors { } def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => + case ArrayType(tpe, _) => + ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toInspector(keyType), toInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 156b090712df2..301cf51c00e2b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,15 +19,18 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.Logging +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical @@ -64,9 +67,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Since HiveQL is case insensitive for table names we make them all lowercase. 
MetastoreRelation( - databaseName, - tblName, - alias)(table.getTTable, partitions.map(part => part.getTPartition)) + databaseName, tblName, alias)( + table.getTTable, partitions.map(part => part.getTPartition))(hive) } def createTable( @@ -200,7 +202,9 @@ object HiveMetastoreTypes extends RegexParsers { "varchar\\((\\d+)\\)".r ^^^ StringType protected lazy val arrayType: Parser[DataType] = - "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType + "array" ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } protected lazy val mapType: Parser[DataType] = "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { @@ -229,10 +233,10 @@ object HiveMetastoreTypes extends RegexParsers { } def toMetastoreType(dt: DataType): String = dt match { - case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>" + case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" - case MapType(keyType, valueType) => + case MapType(keyType, valueType, _) => s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" case StringType => "string" case FloatType => "float" @@ -251,7 +255,11 @@ object HiveMetastoreTypes extends RegexParsers { private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) - extends BaseRelation { + (@transient sqlContext: SQLContext) + extends LeafNode { + + self: Product => + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. // Right now, using org.apache.hadoop.hive.ql.metadata.Table and @@ -264,6 +272,21 @@ private[hive] case class MetastoreRelation new Partition(hiveQlTable, p) } + @transient override lazy val statistics = Statistics( + sizeInBytes = { + // TODO: check if this estimate is valid for tables after partition pruning. + // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. An + // alternative would be going through Hadoop's FileSystem API, which can be expensive if a lot + // of RPCs are involved. Besides `totalSize`, there are also `numFiles`, `numRows`, + // `rawDataSize` keys (see StatsSetupConst in Hive) that we can look at in the future. + BigInt( + Option(hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)) + .map(_.toLong) + .getOrElse(sqlContext.defaultSizeInBytes)) + } + ) + val tableDesc = new TableDesc( Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, @@ -275,14 +298,14 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { - def toAttribute = AttributeReference( - f.getName, - HiveMetastoreTypes.toDataType(f.getType), - // Since data can be dumped in randomly with no validation, everything is nullable. - nullable = true - )(qualifiers = tableName +: alias.toSeq) - } + implicit class SchemaAttribute(f: FieldSchema) { + def toAttribute = AttributeReference( + f.getName, + HiveMetastoreTypes.toDataType(f.getType), + // Since data can be dumped in randomly with no validation, everything is nullable. + nullable = true + )(qualifiers = tableName +: alias.toSeq) + } // Must be a stable value since new attributes are born here. 
val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e6ab68b563f8d..bc2fefafd58c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -44,6 +44,8 @@ private[hive] case class SourceCommand(filePath: String) extends Command private[hive] case class AddFile(filePath: String) extends Command +private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( @@ -96,7 +98,6 @@ private[hive] object HiveQl { "TOK_CREATEINDEX", "TOK_DROPDATABASE", "TOK_DROPINDEX", - "TOK_DROPTABLE", "TOK_MSCK", // TODO(marmbrus): Figure out how view are expanded by hive, as we might need to handle this. @@ -296,8 +297,11 @@ private[hive] object HiveQl { matches.headOption } - assert(remainingNodes.isEmpty, - s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}") + if (remainingNodes.nonEmpty) { + sys.error( + s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. + |You are likely trying to use an unsupported Hive feature."""".stripMargin) + } clauses } @@ -377,6 +381,12 @@ private[hive] object HiveQl { } protected def nodeToPlan(node: Node): LogicalPlan = node match { + // Special drop table that also uncaches. + case Token("TOK_DROPTABLE", + Token("TOK_TABNAME", tableNameParts) :: + ifExists) => + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + DropTable(tableName, ifExists.nonEmpty) // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => @@ -610,7 +620,7 @@ private[hive] object HiveQl { // TOK_DESTINATION means to overwrite the table. 
val resultDestination = (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = if (intoClause.isEmpty) true else false + val overwrite = intoClause.isEmpty nodeToDest( resultDestination, withLimit, @@ -741,7 +751,10 @@ private[hive] object HiveQl { case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => - assert(other.size <= 1, s"Unhandled join child $other") + if (!(other.size <= 1)) { + sys.error(s"Unsupported join operation: $other") + } + val joinType = joinToken match { case "TOK_JOIN" => Inner case "TOK_RIGHTOUTERJOIN" => RightOuter @@ -749,7 +762,6 @@ private[hive] object HiveQl { case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi } - assert(other.size <= 1, "Unhandled join clauses.") Join(nodeToRelation(relation1), nodeToRelation(relation2), joinType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 4d0fab4140b21..2175c5f3835a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -81,6 +81,8 @@ private[hive] trait HiveStrategies { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c3942578d6b5a..82c88280d7754 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -24,6 +24,8 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -31,13 +33,16 @@ import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} +import org.apache.spark.sql.catalyst.types.DataType + /** * A trait for subclasses that handle table scans. */ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[_] + def makeRDDForTable(hiveTable: HiveTable): RDD[Row] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] } @@ -46,7 +51,10 @@ private[hive] sealed trait TableReader { * data warehouse directory. */ private[hive] -class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext) +class HadoopTableReader( + @transient attributes: Seq[Attribute], + @transient relation: MetastoreRelation, + @transient sc: HiveContext) extends TableReader { // Choose the minimum number of splits. 
If mapred.map.tasks is set, then use that unless @@ -63,10 +71,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def hiveConf = _broadcastedHiveConf.value.value - override def makeRDDForTable(hiveTable: HiveTable): RDD[_] = + override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -81,14 +89,14 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[_] = { + filterOpt: Option[PathFilter]): RDD[Row] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. - val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val tablePath = hiveTable.getPath @@ -99,23 +107,20 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + val attrsWithIndex = attributes.zipWithIndex + val mutableRow = new GenericMutableRow(attrsWithIndex.length) val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - // Deserialize each Writable to get the row value. - iter.map { - case v: Writable => deserializer.deserialize(v) - case value => - sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}") - } + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = { + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) @@ -132,9 +137,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon * subdirectory of each partition being read. If None, then all files are accepted. */ def makeRDDForPartitionedTable( - partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[_] = { - + partitionToDeserializer: Map[HivePartition, + Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[Row] = { val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = partition.getPartitionPath @@ -156,33 +161,42 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon } // Create local references so that the outer object isn't serialized. 
- val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer + val mutableRow = new GenericMutableRow(attributes.length) + + // split the attributes (output schema) into 2 categories: + // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the + // index of the attribute in the output Row. + val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { + relation.partitionKeys.indexOf(attr._1) >= 0 + }) + + def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { + partitionKeys.foreach { case (attr, ordinal) => + // get partition key ordinal for a given attribute + val partOridinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null) + } + } + // fill the partition key for the given MutableRow Object + fillPartitionKeys(partValues, mutableRow) val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) hivePartitionRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value - val rowWithPartArr = new Array[Object](2) - - // The update and deserializer initialization are intentionally - // kept out of the below iter.map loop to save performance. - rowWithPartArr.update(1, partValues) val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // Map each tuple to a row object - iter.map { value => - val deserializedRow = deserializer.deserialize(value) - rowWithPartArr.update(0, deserializedRow) - rowWithPartArr.asInstanceOf[Object] - } + // fill the non partition key attributes + HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) } }.toSeq // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[Object](sc.sparkContext) + new EmptyRDD[Row](sc.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -225,10 +239,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon // Only take the value (skip the key) because Hive works only with values. rdd.map(_._2) } - } -private[hive] object HadoopTableReader { +private[hive] object HadoopTableReader extends HiveInspectors { /** * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to * instantiate a HadoopRDD. @@ -241,4 +254,40 @@ private[hive] object HadoopTableReader { val bufferSize = System.getProperty("spark.buffer.size", "65536") jobConf.set("io.file.buffer.size", bufferSize) } + + /** + * Transform the raw data(Writable object) into the Row object for an iterable input + * @param iter Iterable input which represented as Writable object + * @param deserializer Deserializer associated with the input writable object + * @param attrs Represents the row attribute names and its zero-based position in the MutableRow + * @param row reusable MutableRow object + * + * @return Iterable Row object that transformed from the given iterable input. 
+ */ + def fillObject( + iter: Iterator[Writable], + deserializer: Deserializer, + attrs: Seq[(Attribute, Int)], + row: GenericMutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] + // get the field references according to the attributes(output of the reader) required + val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + + // Map each tuple to a row object + iter.map { value => + val raw = deserializer.deserialize(value) + var idx = 0; + while (idx < fieldRefs.length) { + val fieldRef = fieldRefs(idx)._1 + val fieldIdx = fieldRefs(idx)._2 + val fieldValue = soi.getStructFieldData(raw, fieldRef) + + row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) + + idx += 1 + } + + row: Row + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 9386008d02d51..d890df866fbe5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -53,15 +53,24 @@ object TestHive * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. */ -class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { +class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. System.clearProperty("spark.hostPort") - override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath - override lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath + lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + + /** Sets up the system initially or after a RESET command */ + protected def configure() { + setConf("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metastorePath;create=true") + setConf("hive.metastore.warehouse.dir", warehousePath) + } + + configure() // Must be called before initializing the catalog below. /** The location of the compiled hive distribution */ lazy val hiveHome = envVarToFile("HIVE_HOME") @@ -139,7 +148,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { describedTables ++ logical.collect { case UnresolvedRelation(databaseName, name, _) => name } val referencedTestTables = referencedTables.filter(testTables.contains) - logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. analyzer(logical) @@ -264,7 +273,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infite mutually recursive table loading. loadedTables += name - logger.info(s"Loading test table $name") + logInfo(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) @@ -288,8 +297,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. 
- org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger => - logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } // It is important that we RESET first as broken hooks that might have been set could break @@ -303,7 +312,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { loadedTables.clear() catalog.client.getAllTables("default").foreach { t => - logger.debug(s"Deleting table $t") + logDebug(s"Deleting table $t") val table = catalog.client.getTable("default", t) catalog.client.getIndexes("default", t, 255).foreach { index => @@ -316,7 +325,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { } catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - logger.debug(s"Dropping Database: $db") + logDebug(s"Dropping Database: $db") catalog.client.dropDatabase(db, true, false, true) } @@ -338,7 +347,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + logError(s"FATAL ERROR: Failed to reset TestDB state. $e") // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala index c9ee162191c96..a201d2349a2ef 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.api.java import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.api.java.{JavaSQLContext, JavaSchemaRDD} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.{HiveContext, HiveQl} /** @@ -28,9 +29,21 @@ class JavaHiveContext(sparkContext: JavaSparkContext) extends JavaSQLContext(spa override val sqlContext = new HiveContext(sparkContext) + override def sql(sqlText: String): JavaSchemaRDD = { + // TODO: Create a framework for registering parsers instead of just hardcoding if statements. + if (sqlContext.dialect == "sql") { + super.sql(sqlText) + } else if (sqlContext.dialect == "hiveql") { + new JavaSchemaRDD(sqlContext, HiveQl.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: ${sqlContext.dialect}. Try 'sql' or 'hiveql'") + } + } + /** - * Executes a query expressed in HiveQL, returning the result as a JavaSchemaRDD. + * DEPRECATED: Use sql(...) Instead */ + @Deprecated def hql(hqlQuery: String): JavaSchemaRDD = new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala new file mode 100644 index 0000000000000..9cd0c86c6c796 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.execution.{Command, LeafNode} +import org.apache.spark.sql.hive.HiveContext + +/** + * :: DeveloperApi :: + * Drops a table from the metastore and removes it if it is cached. + */ +@DeveloperApi +case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with Command { + + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + def output = Seq.empty + + override protected[sql] lazy val sideEffectResult: Seq[Any] = { + val ifExistsClause = if (ifExists) "IF EXISTS " else "" + hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") + hiveContext.catalog.unregisterTable(None, tableName) + Seq.empty + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index e7016fa16eea9..8920e2a76a27f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.util.MutablePair /** * :: DeveloperApi :: @@ -50,8 +49,7 @@ case class HiveTableScan( relation: MetastoreRelation, partitionPruningPred: Option[Expression])( @transient val context: HiveContext) - extends LeafNode - with HiveInspectors { + extends LeafNode { require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") @@ -67,42 +65,7 @@ case class HiveTableScan( } @transient - private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context) - - /** - * The hive object inspector for this table, which can be used to extract values from the - * serialized row representation. - */ - @transient - private[this] lazy val objectInspector = - relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] - - /** - * Functions that extract the requested attributes from the hive output. Partitioned values are - * casted from string to its declared data type. 
- */ - @transient - protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { - attributes.map { a => - val ordinal = relation.partitionKeys.indexOf(a) - if (ordinal >= 0) { - val dataType = relation.partitionKeys(ordinal).dataType - (_: Any, partitionKeys: Array[String]) => { - castFromString(partitionKeys(ordinal), dataType) - } - } else { - val ref = objectInspector.getAllStructFieldRefs - .find(_.getFieldName == a.name) - .getOrElse(sys.error(s"Can't find attribute $a")) - val fieldObjectInspector = ref.getFieldObjectInspector - - (row: Any, _: Array[String]) => { - val data = objectInspector.getStructFieldData(row, ref) - unwrapData(data, fieldObjectInspector) - } - } - } - } + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context) private[this] def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) @@ -114,6 +77,7 @@ case class HiveTableScan( val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") if (attributes.size == relation.output.size) { + // SQLContext#pruneFilterProject guarantees no duplicated value in `attributes` ColumnProjectionUtils.setFullyReadColumns(hiveConf) } else { ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) @@ -140,12 +104,6 @@ case class HiveTableScan( addColumnMetadataToConf(context.hiveconf) - private def inputRdd = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) - } - /** * Prunes partitions not involve the query plan. * @@ -169,44 +127,10 @@ case class HiveTableScan( } } - override def execute() = { - inputRdd.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val mutableRow = new GenericMutableRow(attributes.length) - val mutablePair = new MutablePair[Any, Array[String]]() - val buffered = iterator.buffered - - // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern - // matching are avoided intentionally. 
- val rowsAndPartitionKeys = buffered.head match { - // With partition keys - case _: Array[Any] => - buffered.map { case array: Array[Any] => - val deserializedRow = array(0) - val partitionKeys = array(1).asInstanceOf[Array[String]] - mutablePair.update(deserializedRow, partitionKeys) - } - - // Without partition keys - case _ => - val emptyPartitionKeys = Array.empty[String] - buffered.map { deserializedRow => - mutablePair.update(deserializedRow, emptyPartitionKeys) - } - } - - rowsAndPartitionKeys.map { pair => - var i = 0 - while (i < attributes.length) { - mutableRow(i) = attributeFunctions(i)(pair._1, pair._2) - i += 1 - } - mutableRow: Row - } - } - } + override def execute() = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } override def output = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c2b0b00aa5852..39033bdeac4b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -131,7 +131,7 @@ case class InsertIntoHiveTable( conf, SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) - logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) writer.preSetup() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8258ee5fef0eb..0c8f676e9c5c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -67,7 +67,7 @@ case class ScriptTransformation( } } readerThread.start() - val outputProjection = new Projection(input) + val outputProjection = new InterpretedProjection(input, child.output) iter .map(outputProjection) // TODO: Use SerDe diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 057eb60a02612..179aac5cbd5cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { +private[hive] abstract class HiveFunctionRegistry + extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) @@ -92,9 +93,8 @@ 
private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu } private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) - extends HiveUdf { + extends HiveUdf with HiveInspectors { - import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @transient @@ -119,7 +119,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) (a: Any) => { - logger.debug( + logDebug( s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") // We must make sure that primitives get boxed java style. if (a == null) { @@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val function: GenericUDTF = createFunction() + @transient protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + @transient protected lazy val outputInspectors = { val structInspector = function.initialize(inputInspectors.toArray) structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) @@ -278,7 +280,7 @@ private[hive] case class HiveGenericUdtf( override def eval(input: Row): TraversableOnce[Row] = { outputInspectors // Make sure initialized. - val inputProjection = new Projection(children) + val inputProjection = new InterpretedProjection(children) val collector = new UDTFCollector function.setCollector(collector) @@ -332,7 +334,7 @@ private[hive] case class HiveUdafFunction( override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) @transient - val inputProjection = new Projection(exprs) + val inputProjection = new InterpretedProjection(exprs) def update(input: Row): Unit = { val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d new file mode 100644 index 0000000000000..00750edc07d64 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 new file mode 100644 index 0000000000000..00750edc07d64 --- /dev/null +++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 b/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 new file mode 100644 index 0000000000000..3f2cab688ccc2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 @@ -0,0 +1,136 @@ +0 +5 +12 +15 +18 +24 +26 +35 +37 +42 +51 +58 +67 +70 +72 +76 +83 +84 +90 +95 +97 +98 +100 +103 +104 +113 +118 +119 +120 +125 +128 +129 +134 +137 +138 +146 +149 +152 +164 +165 +167 +169 +172 +174 +175 +176 +179 +187 +191 +193 +195 +197 +199 +200 +203 +205 +207 +208 +209 +213 +216 +217 +219 +221 +223 +224 +229 +230 +233 +237 +238 +239 +242 +255 +256 +265 +272 +273 +277 +278 +280 +281 +282 +288 +298 +307 +309 +311 +316 +317 +318 +321 +322 +325 +327 +331 +333 +342 +344 +348 +353 +367 +369 +382 +384 +395 +396 +397 +399 +401 +403 +404 +406 +409 +413 +414 +417 +424 +429 +430 +431 +438 +439 +454 +458 +459 
+462 +463 +466 +468 +469 +478 +480 +489 +492 +498 diff --git a/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 new file mode 100644 index 0000000000000..f369f21e1833f --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 @@ -0,0 +1,2 @@ +100 100 2010-01-01 +200 200 2010-01-02 diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index 11d8b1f0a3d96..95921c3d7ae09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -51,9 +51,9 @@ class QueryTest extends FunSuite { fail( s""" |Exception thrown while executing query: - |${rdd.logicalPlan} + |${rdd.queryExecution} |== Exception == - |$e + |${stackTraceToString(e)} """.stripMargin) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 3132d0112c708..188579edd7bdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive class CachedTableSuite extends HiveComparisonTest 
{ + import TestHive._ + TestHive.loadTestTable("src") test("cache table") { @@ -32,6 +34,20 @@ class CachedTableSuite extends HiveComparisonTest { createQueryTest("read from cached table", "SELECT * FROM src LIMIT 1", reset = false) + test("Drop cached table") { + sql("CREATE TABLE test(a INT)") + cacheTable("test") + sql("SELECT * FROM test").collect() + sql("DROP TABLE test") + intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { + sql("SELECT * FROM test").collect() + } + } + + test("DROP nonexistant table") { + sql("DROP TABLE IF EXISTS nonexistantTable") + } + test("check that table is cached and uncache") { TestHive.table("src").queryExecution.analyzed match { case _ : InMemoryRelation => // Found evidence of caching @@ -58,14 +74,14 @@ class CachedTableSuite extends HiveComparisonTest { } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.hql("CACHE TABLE src") + TestHive.sql("CACHE TABLE src") TestHive.table("src").queryExecution.executedPlan match { case _: InMemoryColumnarTableScan => // Found evidence of caching case _ => fail(s"Table 'src' should be cached") } assert(TestHive.isCached("src"), "Table 'src' should be cached") - TestHive.hql("UNCACHE TABLE src") + TestHive.sql("UNCACHE TABLE src") TestHive.table("src").queryExecution.executedPlan match { case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") case _ => // Found evidence of uncaching diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 833f3502154f3..7e323146f9da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -28,7 +28,7 @@ case class TestData(key: Int, value: String) class InsertIntoHiveTableSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") test("insertInto() HiveTable") { createTable[TestData]("createAndInsertTest") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala new file mode 100644 index 0000000000000..bf5931bbf97ee --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +class StatisticsSuite extends QueryTest { + + test("analyze MetastoreRelations") { + def queryTotalSize(tableName: String): BigInt = + catalog.lookupRelation(None, tableName).statistics.sizeInBytes + + // Non-partitioned table + sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + + assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) + + analyze("analyzeTable") + + assert(queryTotalSize("analyzeTable") === BigInt(11624)) + + sql("DROP TABLE analyzeTable").collect() + + // Partitioned table + sql( + """ + |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') + |SELECT * FROM src + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') + |SELECT * FROM src + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') + |SELECT * FROM src + """.stripMargin).collect() + + assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) + + analyze("analyzeTable_part") + + assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) + + sql("DROP TABLE analyzeTable_part").collect() + + // Try to analyze a temp table + sql("""SELECT * FROM src""").registerTempTable("tempTable") + intercept[NotImplementedError] { + analyze("tempTable") + } + catalog.unregisterTable(None, "tempTable") + } + + test("estimates the size of a test MetastoreRelation") { + val rdd = sql("""SELECT * FROM src""") + val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => + mr.statistics.sizeInBytes + } + assert(sizes.size === 1) + assert(sizes(0).equals(BigInt(5812)), + s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") + } + + test("auto converts to broadcast hash join, by size estimate of a relation") { + def mkTest( + before: () => Unit, + after: () => Unit, + query: String, + expectedAnswer: Seq[Any], + ct: ClassTag[_]) = { + before() + + var rdd = sql(query) + + // Assert src has a size smaller than the threshold. + val sizes = rdd.queryExecution.analyzed.collect { + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes + } + assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold, + s"query should contain two relations, each of which has size smaller than autoConvertSize") + + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. 
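The assertions that follow lean on the plan tree's `collect`, which walks the plan and gathers every node matched by a partial function. The sketch below reproduces that traversal on a tiny stand-in tree (hypothetical `PlanNode` and `BroadcastJoinNode` classes, not Catalyst's operators):

sealed trait PlanNode {
  def children: Seq[PlanNode]
  // Pre-order traversal collecting every node the partial function matches.
  def collect[B](pf: PartialFunction[PlanNode, B]): Seq[B] =
    pf.lift(this).toSeq ++ children.flatMap(_.collect(pf))
}
case class ScanNode(table: String) extends PlanNode { val children = Nil }
case class BroadcastJoinNode(left: PlanNode, right: PlanNode) extends PlanNode {
  val children = Seq(left, right)
}

object CollectExample extends App {
  val plan: PlanNode = BroadcastJoinNode(ScanNode("src"), ScanNode("src"))
  val bhj = plan.collect { case j: BroadcastJoinNode => j }
  assert(bhj.size == 1)  // mirrors the broadcast-join assertion in the test above
}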
+ var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + assert(bhj.size === 1, + s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + + checkAnswer(rdd, expectedAnswer) // check correctness of output + + TestHive.settings.synchronized { + val tmp = autoBroadcastJoinThreshold + + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + rdd = sql(query) + bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") + + val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + assert(shj.size === 1, + "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") + + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + } + + after() + } + + /** Tests for MetastoreRelation */ + val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" + val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238")) + mkTest( + () => (), + () => (), + metastoreQuery, + metastoreAnswer, + implicitly[ClassTag[MetastoreRelation]] + ) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 10c8069a624e6..9644b707eb1a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -40,7 +40,7 @@ class JavaHiveQLSuite extends FunSuite { ignore("SELECT * FROM src") { assert( - javaHiveCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === + javaHiveCtx.sql("SELECT * FROM src").collect().map(_.getInt(0)) === TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) } @@ -56,33 +56,34 @@ class JavaHiveQLSuite extends FunSuite { val tableName = "test_native_commands" assertResult(0) { - javaHiveCtx.hql(s"DROP TABLE IF EXISTS $tableName").count() + javaHiveCtx.sql(s"DROP TABLE IF EXISTS $tableName").count() } assertResult(0) { - javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.hql("SHOW TABLES").registerAsTable("show_tables") + javaHiveCtx.sql("SHOW TABLES").registerTempTable("show_tables") assert( javaHiveCtx - .hql("SELECT result FROM show_tables") + .sql("SELECT result FROM show_tables") .collect() .map(_.getString(0)) .contains(tableName)) assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - javaHiveCtx.hql(s"DESCRIBE $tableName").registerAsTable("describe_table") + javaHiveCtx.sql(s"DESCRIBE $tableName").registerTempTable("describe_table") + javaHiveCtx - .hql("SELECT result FROM describe_table") + .sql("SELECT result FROM describe_table") .collect() .map(_.getString(0).split("\t").map(_.trim)) .toArray } - assert(isExplanation(javaHiveCtx.hql( + assert(isExplanation(javaHiveCtx.sql( s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() @@ -90,7 +91,7 @@ class JavaHiveQLSuite extends FunSuite { ignore("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail 
assert(Try(TestHive.table(tableName)).isSuccess) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index b4dbf2b115799..0ebaf6ffd5458 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -21,7 +21,7 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} @@ -132,7 +132,7 @@ abstract class HiveComparisonTest answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { - case _: Join | _: Aggregate | _: BaseRelation | _: Generate | _: Sample | _: Distinct => false + case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false case PhysicalOperation(_, _, Sort(_, _)) => true case _ => plan.children.iterator.exists(isSorted) } @@ -197,7 +197,7 @@ abstract class HiveComparisonTest // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return - case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") + case (shardId, _) => logDebug(s"Shard $shardId includes test '$testCaseName'") } // Skip tests found in directories specified by user. @@ -213,13 +213,13 @@ abstract class HiveComparisonTest .map(new File(_, testCaseName)) .filter(_.exists) if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { - logger.debug( + logDebug( s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") return } test(testCaseName) { - logger.debug(s"=== HIVE TEST: $testCaseName ===") + logDebug(s"=== HIVE TEST: $testCaseName ===") // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) @@ -235,19 +235,19 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") if (allQueries != queryList) - logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") + logWarning(s"Simplifications made on unsupported operations for test $testCaseName") lazy val consoleTestCase = { val quotes = "\"\"\"" queryList.zipWithIndex.map { case (query, i) => - s"""val q$i = hql($quotes$query$quotes); q$i.collect()""" + s"""val q$i = sql($quotes$query$quotes); q$i.collect()""" }.mkString("\n== Console version of this test ==\n", "\n", "\n") } try { // MINOR HACK: You must run a query before calling reset the first time. 
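The pattern running through this hunk is purely mechanical: the test harness now mixes in Spark core's `Logging` trait and calls its `logDebug`/`logWarning`/`logError` methods instead of going through a `logger` field. A toy sketch of that call pattern (the `Logging` trait below is a stand-in, not `org.apache.spark.Logging`):

trait Logging {
  // Minimal stand-in: the real trait routes these through a lazily created logger.
  protected def logDebug(msg: => String): Unit = println(s"DEBUG $msg")
  protected def logWarning(msg: => String): Unit = println(s"WARN  $msg")
}

class ComparisonRunner extends Logging {
  def run(testCaseName: String): Unit = {
    logDebug(s"=== HIVE TEST: $testCaseName ===")                // was logger.debug(...)
    logWarning(s"Running queries for $testCaseName with hive.")  // was logger.warn(...)
  }
}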
- TestHive.hql("SHOW TABLES") + TestHive.sql("SHOW TABLES") if (reset) { TestHive.reset() } val hiveCacheFiles = queryList.zipWithIndex.map { @@ -257,11 +257,11 @@ abstract class HiveComparisonTest } val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => - logger.debug(s"Looking for cached answer file $cachedAnswerFile.") + logDebug(s"Looking for cached answer file $cachedAnswerFile.") if (cachedAnswerFile.exists) { Some(fileToString(cachedAnswerFile)) } else { - logger.debug(s"File $cachedAnswerFile not found") + logDebug(s"File $cachedAnswerFile not found") None } }.map { @@ -272,7 +272,7 @@ abstract class HiveComparisonTest val hiveResults: Seq[Seq[String]] = if (hiveCachedResults.size == queryList.size) { - logger.info(s"Using answer cache for test: $testCaseName") + logInfo(s"Using answer cache for test: $testCaseName") hiveCachedResults } else { @@ -287,7 +287,7 @@ abstract class HiveComparisonTest if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) sys.error("hive exec hooks not supported for tests.") - logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i+1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. @@ -351,7 +351,7 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") if (recomputeCache) { - logger.warn(s"Clearing cache files for failed test $testCaseName") + logWarning(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) } @@ -380,7 +380,7 @@ abstract class HiveComparisonTest TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") // The testing setup traps exits so wait here for a long time so the developer can see when things started // to go wrong. Thread.sleep(1000000) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 50ab71a9003d3..02518d516261b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -53,7 +53,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { testCases.sorted.foreach { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { - logger.debug(s"Blacklisted test skipped $testCaseName") + logDebug(s"Blacklisted test skipped $testCaseName") } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { // Build a test case and submit it to scala test framework... 
val queriesString = fileToString(testCaseFile) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a8623b64c656f..fdb2f41f5a5b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{Row, SchemaRDD} case class TestData(a: Int, b: String) @@ -30,6 +32,21 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("null case", + "SELECT case when(true) then 1 else null end FROM src LIMIT 1") + + createQueryTest("single case", + """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") + + createQueryTest("double case", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""") + + createQueryTest("case else null", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""") + + createQueryTest("having no references", + "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") + createQueryTest("boolean = number", """ |SELECT @@ -43,8 +60,8 @@ class HiveQuerySuite extends HiveComparisonTest { """.stripMargin) test("CREATE TABLE AS runs once") { - hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, "Incorrect number of rows in created table") } @@ -58,12 +75,14 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") test("Query expressed in SQL") { + setConf("spark.sql.dialect", "sql") assert(sql("SELECT 1").collect() === Array(Seq(1))) + setConf("spark.sql.dialect", "hiveql") + } test("Query expressed in HiveQL") { - hql("FROM src SELECT key").collect() - hiveql("FROM src SELECT key").collect() + sql("FROM src SELECT key").collect() } createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT", @@ -179,12 +198,12 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") test("sampling") { - hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") } test("SchemaRDD toString") { - hql("SHOW TABLES").toString - hql("SELECT * FROM src").toString + sql("SHOW TABLES").toString + sql("SELECT * FROM src").toString } createQueryTest("case statements with key #1", @@ -212,8 +231,8 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") test("implement identity function using case statement") { - val actual = hql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet - val expected = hql("SELECT key FROM src").collect().toSet + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet + val expected = sql("SELECT key FROM src").collect().toSet assert(actual === expected) } @@ -221,7 +240,7 
@@ class HiveQuerySuite extends HiveComparisonTest { // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. ignore("non-boolean conditions in a CaseWhen are illegal") { intercept[Exception] { - hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() } } @@ -233,10 +252,10 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerAsTable("REGisteredTABle") + testData.registerTempTable("REGisteredTABle") assertResult(Array(Array(2, "str2"))) { - hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + + sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } } @@ -247,9 +266,9 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-1704: Explain commands as a SchemaRDD") { - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val rdd = hql("explain select key, count(value) from src group by key") + val rdd = sql("explain select key, count(value) from src group by key") assert(isExplanation(rdd)) TestHive.reset() @@ -258,9 +277,9 @@ class HiveQuerySuite extends HiveComparisonTest { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") val results = - hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() .map(x => Pair(x.getString(0), x.getInt(1))) @@ -269,39 +288,39 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { - hql("select key, count(*) c from src group by key having c").collect() + sql("select key, count(*) c from src group by key having c").collect() } test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { - assert(hql("select key from src having key > 490").collect().size < 100) + assert(sql("select key from src having key > 490").collect().size < 100) } test("Query Hive native command execution result") { val tableName = "test_native_commands" assertResult(0) { - hql(s"DROP TABLE IF EXISTS $tableName").count() + sql(s"DROP TABLE IF EXISTS $tableName").count() } assertResult(0) { - hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } assert( - hql("SHOW TABLES") + sql("SHOW TABLES") .select('result) .collect() .map(_.getString(0)) .contains(tableName)) - assert(isExplanation(hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() } test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail 
assert(Try(table(tableName)).isSuccess) @@ -311,9 +330,9 @@ class HiveQuerySuite extends HiveComparisonTest { } test("DESCRIBE commands") { - hql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - hql( + sql( """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') |SELECT key, value """.stripMargin) @@ -328,7 +347,7 @@ class HiveQuerySuite extends HiveComparisonTest { Array("# col_name", "data_type", "comment"), Array("dt", "string", null)) ) { - hql("DESCRIBE test_describe_commands1") + sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } @@ -343,14 +362,14 @@ class HiveQuerySuite extends HiveComparisonTest { Array("# col_name", "data_type", "comment"), Array("dt", "string", null)) ) { - hql("DESCRIBE default.test_describe_commands1") + sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - hql("DESCRIBE test_describe_commands1 value") + sql("DESCRIBE test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -358,7 +377,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - hql("DESCRIBE default.test_describe_commands1 value") + sql("DESCRIBE default.test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -376,7 +395,7 @@ class HiveQuerySuite extends HiveComparisonTest { Array("", "", ""), Array("dt", "string", "None")) ) { - hql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + sql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -387,7 +406,7 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerAsTable("test_describe_commands2") + testData.registerTempTable("test_describe_commands2") assertResult( Array( @@ -395,16 +414,16 @@ class HiveQuerySuite extends HiveComparisonTest { Array("a", "IntegerType", null), Array("b", "StringType", null)) ) { - hql("DESCRIBE test_describe_commands2") + sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) .collect() } } test("SPARK-2263: Insert Map values") { - hql("CREATE TABLE m(value MAP)") - hql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - hql("SELECT * FROM m").collect().zip(hql("SELECT * FROM src LIMIT 10").collect()).map { + sql("CREATE TABLE m(value MAP)") + sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -416,19 +435,19 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "val0,val_1,val2.3,my_table" - hql(s"set $testKey=$testVal") - assert(get(testKey, testVal + "_") == testVal) + sql(s"set $testKey=$testVal") + assert(getConf(testKey, testVal + "_") == testVal) - hql("set mapred.reduce.tasks=20") - 
assert(get("mapred.reduce.tasks", "0") == "20") - hql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + sql("set some.property=20") + assert(getConf("some.property", "0") == "20") + sql("set some.property = 40") + assert(getConf("some.property", "0") == "40") - hql(s"set $testKey=$testVal") - assert(get(testKey, "0") == testVal) + sql(s"set $testKey=$testVal") + assert(getConf(testKey, "0") == testVal) - hql(s"set $testKey=") - assert(get(testKey, "0") == "") + sql(s"set $testKey=") + assert(getConf(testKey, "0") == "") } test("SET commands semantics for a HiveContext") { @@ -436,63 +455,62 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet clear() // "set" itself returns all config variables currently specified in SQLConf. - assert(hql("SET").collect().size == 0) + // TODO: Should we be listing the default here always? probably... + assert(sql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } - hql(s"SET ${testKey + testKey}=${testVal + testVal}") + sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + sql(s"SET").collect().map(_.getString(0)) } // "set key" - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } - // Assert that sql() should have the same effects as hql() by repeating the above using sql(). + // Assert that sql() should have the same effects as sql() by repeating the above using sql(). 
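The rewritten assertions above all depend on `SET` now returning rows of a single string column in `key=value` form, with `SET key` echoing the current (possibly empty) value. A rough sketch of that round trip with a toy in-memory conf (not the real SQLConf); `args` is the text after the SET keyword:

import scala.collection.mutable

object ToySetCommand {
  private val settings = mutable.Map.empty[String, String]

  // Mirrors the three shapes asserted on above: SET, SET key, SET key=value.
  def run(args: String): Seq[String] = args.trim match {
    case ""                     => settings.map { case (k, v) => s"$k=$v" }.toSeq
    case kv if kv.contains("=") =>
      val Array(k, v) = kv.split("=", 2)
      settings(k.trim) = v.trim
      Seq(s"${k.trim}=${v.trim}")
    case k                      => Seq(s"$k=${settings.getOrElse(k, "")}")
  }
}

// ToySetCommand.run("spark.sql.key.usedfortestonly=test.val.0")
//   == Seq("spark.sql.key.usedfortestonly=test.val.0"), matching the assertResult above.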
clear() assert(sql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + sql("SET").collect().map(_.getString(0)) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + sql("SET").collect().map(_.getString(0)) } - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index fb03db12a0b01..6b3ffd1c0ffe2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -54,15 +54,15 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("caseSensitivityTest") + .registerTempTable("caseSensitivityTest") - hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } test("nested repeated resolution") { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("nestedRepeatedTest") - assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) + .registerTempTable("nestedRepeatedTest") + assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala new file mode 100644 index 0000000000000..c5736723b47c0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +class HiveTableScanSuite extends HiveComparisonTest { + + createQueryTest("partition_based_table_scan_with_different_serde", + """ + |CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) + |ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' + |STORED AS RCFILE; + | + |FROM src + |INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') + |SELECT 100,100 LIMIT 1; + | + |ALTER TABLE part_scan_test SET SERDE + |'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe'; + | + |FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') + |SELECT 200,200 LIMIT 1; + | + |SELECT * from part_scan_test; + """.stripMargin) +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 7436de264a1e1..c3c18cf8ccac3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,7 +35,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f944d010660eb..b6b8592344ef5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject */ class HiveUdfSuite extends HiveComparisonTest { - TestHive.hql( + TestHive.sql( """ |CREATE EXTERNAL TABLE hiveUdfTestTable ( | pair STRUCT @@ -48,16 +48,16 @@ class HiveUdfSuite extends HiveComparisonTest { """.stripMargin.format(classOf[PairSerDe].getName) ) - TestHive.hql( + TestHive.sql( "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) ) - TestHive.hql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) + TestHive.sql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) - TestHive.hql("SELECT testUdf(pair) FROM hiveUdfTestTable") + TestHive.sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - TestHive.hql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + TestHive.sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 34d8a061ccc83..1a6dbc0ce0c0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConversions._ */ class PruningSuite extends HiveComparisonTest { // MINOR HACK: You must run a query before calling reset the first time. - TestHive.hql("SHOW TABLES") + TestHive.sql("SHOW TABLES") // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset // the environment to ensure all referenced tables in this suites are not cached in-memory. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala new file mode 100644 index 0000000000000..635a9fb0d56cb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +/** + * A collection of hive query tests where we generate the answers ourselves instead of depending on + * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is + * valid, but Hive currently cannot execute it. 
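The doc comment above describes self-checked tests: the expected answer comes from a second Spark query instead of a cached Hive golden file, so the comparison only needs an order-insensitive equality over rows. A rough sketch of that check (toy rows, not the real QueryTest.checkAnswer):

object AnswerCheck {
  type Row = Seq[Any]

  // Order-insensitive comparison in the spirit of QueryTest.checkAnswer:
  // canonicalise each row to a string, sort, and compare.
  def sameAnswer(actual: Seq[Row], expected: Seq[Row]): Boolean = {
    def canon(rows: Seq[Row]): Seq[String] = rows.map(_.mkString("|")).sorted
    canon(actual) == canon(expected)
  }
}

// e.g. AnswerCheck.sameAnswer(Seq(Seq(1, "a"), Seq(2, "b")), Seq(Seq(2, "b"), Seq(1, "a"))) == true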
+ */ +class SQLQuerySuite extends QueryTest { + test("ordering not in select") { + checkAnswer( + sql("SELECT key FROM src ORDER BY value"), + sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq) + } + + test("ordering not in agg") { + checkAnswer( + sql("SELECT key FROM src GROUP BY key, value ORDER BY value"), + sql(""" + SELECT key + FROM ( + SELECT key, value + FROM src + GROUP BY key, value + ORDER BY value) a""").collect().toSeq) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 91ad59d7f82c0..6f57fe8958387 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} @@ -27,6 +29,8 @@ import org.apache.spark.util.Utils // Implicits import org.apache.spark.sql.hive.test.TestHive._ +case class Cases(lower: String, UPPER: String) + class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val dirname = Utils.createTempDir() @@ -35,9 +39,9 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft override def beforeAll() { // write test data - ParquetTestData.writeFile + ParquetTestData.writeFile() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") } override def afterAll() { @@ -55,35 +59,49 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft Utils.deleteRecursively(dirname) } + test("Case insensitive attribute names") { + val tempFile = File.createTempFile("parquet", "") + tempFile.delete() + sparkContext.parallelize(1 to 10) + .map(_.toString) + .map(i => Cases(i, i)) + .saveAsParquetFile(tempFile.getCanonicalPath) + + parquetFile(tempFile.getCanonicalPath).registerTempTable("cases") + sql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + sql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + } + test("SELECT on Parquet table") { - val rdd = hql("SELECT * FROM testsource").collect() + val rdd = sql("SELECT * FROM testsource").collect() assert(rdd != null) assert(rdd.forall(_.size == 6)) } test("Simple column projection + filter on Parquet table") { - val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() + val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() assert(rdd.size === 5, "Filter returned incorrect number of rows") assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") } test("Converting Hive to Parquet Table via saveAsParquetFile") { - hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") - val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0)) - val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0)) + sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") + val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0)) + val rddTwo = sql("SELECT * 
from ptable").collect().sortBy(_.getInt(0)) + compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) } test("INSERT OVERWRITE TABLE Parquet table") { - hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") // let's do three overwrites for good measure - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - val rddCopy = hql("SELECT * FROM ptable").collect() - val rddOrig = hql("SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + val rddCopy = sql("SELECT * FROM ptable").collect() + val rddOrig = sql("SELECT * FROM testsource").collect() assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) } diff --git a/streaming/pom.xml b/streaming/pom.xml index f60697ce745b7..1072f74aea0d9 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming_2.10 - streaming + streaming jar Spark Project Streaming @@ -58,6 +58,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ac56ff709c1c4..b780282bdac37 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -35,7 +35,6 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.appName - val sparkHome = ssc.sc.getSparkHome.getOrElse(null) val jars = ssc.sc.jars val graph = ssc.graph val checkpointDir = ssc.checkpointDir diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 09be3a50d2dfa..1f0244c251eba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -138,7 +138,7 @@ private[streaming] abstract class ReceiverSupervisor( onReceiverStop(message, error) } catch { case t: Throwable => - stop("Error stopping receiver " + streamId, Some(t)) + logError("Error stopping receiver " + streamId + t.getStackTraceString) } } diff --git a/tools/pom.xml b/tools/pom.xml index c0ee8faa7a615..97abb6b2b63e0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -27,7 +27,7 @@ org.apache.spark spark-tools_2.10 - tools + tools jar Spark Project Tools diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 566983675bff5..bcf6d43ab34eb 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ 
b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -68,12 +68,11 @@ object GenerateMIMAIgnore { for (className <- classes) { try { val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) - val moduleSymbol = mirror.staticModule(className) // TODO: see if it is necessary. + val moduleSymbol = mirror.staticModule(className) val directlyPrivateSpark = isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) - val developerApi = isDeveloperApi(classSymbol) - val experimental = isExperimental(classSymbol) - + val developerApi = isDeveloperApi(classSymbol) || isDeveloperApi(moduleSymbol) + val experimental = isExperimental(classSymbol) || isExperimental(moduleSymbol) /* Inner classes defined within a private[spark] class or object are effectively invisible, so we account for them as package private. */ lazy val indirectlyPrivateSpark = { @@ -87,10 +86,9 @@ object GenerateMIMAIgnore { } if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) { ignoredClasses += className - } else { - // check if this class has package-private/annotated members. - ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } + // check if this class has package-private/annotated members. + ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { case _: Throwable => println("Error instrumenting class:" + className) @@ -115,9 +113,11 @@ object GenerateMIMAIgnore { } private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { - classSymbol.typeSignature.members - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ - getInnerFunctions(classSymbol) + classSymbol.typeSignature.members.filterNot(x => + x.fullName.startsWith("java") || x.fullName.startsWith("scala") + ).filter(x => + isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x) + ).map(_.fullName) ++ getInnerFunctions(classSymbol) } def main(args: Array[String]) { @@ -137,8 +137,7 @@ object GenerateMIMAIgnore { name.endsWith("$class") || name.contains("$sp") || name.contains("hive") || - name.contains("Hive") || - name.contains("repl") + name.contains("Hive") } /** diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8e8c35615a711..8a05fcb449aa6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -61,10 +61,9 @@ object StoragePerfTester { for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) } - writers.map {w => - w.commit() + writers.map { w => + w.commitAndClose() total.addAndGet(w.fileSegment().length) - w.close() } shuffle.releaseWriters(true) diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 5b13a1f002d6e..51744ece0412d 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-alpha + yarn-alpha org.apache.spark diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index a1298e8f30b5c..1da0a1b675554 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import 
org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -37,7 +37,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{SparkException, Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The @@ -109,7 +109,7 @@ trait ClientBase extends Logging { if (amMem > maxMem) { val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." - .format(args.amMemory, maxMem) + .format(amMem, maxMem) logError(errorMessage) throw new IllegalArgumentException(errorMessage) } @@ -191,23 +191,11 @@ trait ClientBase extends Logging { // Upload Spark and the application JAR to the remote file system if necessary. Add them as // local resources to the application master. val fs = FileSystem.get(conf) - - val delegTokenRenewer = Master.getMasterPrincipal(conf) - if (UserGroupInformation.isSecurityEnabled()) { - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - } val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort - - if (UserGroupInformation.isSecurityEnabled()) { - val dstFs = dst.getFileSystem(conf) - dstFs.addDelegationTokens(delegTokenRenewer, credentials) - } + val nns = ClientBase.getNameNodesToAccess(sparkConf) + dst + ClientBase.obtainTokensForNamenodes(nns, conf, credentials) + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) @@ -271,6 +259,14 @@ trait ClientBase extends Logging { localResources } + /** Get all application master environment variables set on this SparkConf */ + def getAppMasterEnv: Seq[(String, String)] = { + val prefix = "spark.yarn.appMasterEnv." + sparkConf.getAll.filter{case (k, v) => k.startsWith(prefix)} + .map{case (k, v) => (k.substring(prefix.length), v)} + } + + def setupLaunchEnv( localResources: HashMap[String, LocalResource], stagingDir: String): HashMap[String, String] = { @@ -288,6 +284,11 @@ trait ClientBase extends Logging { distCacheMgr.setDistFilesEnv(env) distCacheMgr.setDistArchivesEnv(env) + getAppMasterEnv.foreach { case (key, value) => + YarnSparkHadoopUtil.addToEnvironment(env, key, value, File.pathSeparator) + } + + // Keep this for backwards compatibility but users should move to the config sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => // Allow users to specify some environment variables. 
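`getAppMasterEnv` above is just a prefix filter over the SparkConf key/value pairs, and a matching `getExecutorEnv` is applied in the executor launch path just below. A self-contained sketch of the same transformation over a plain `Map` (hypothetical helper, not the Spark API):

object AppMasterEnv {
  private val Prefix = "spark.yarn.appMasterEnv."

  // Keep only spark.yarn.appMasterEnv.* entries and strip the prefix,
  // in the same spirit as getAppMasterEnv above.
  def fromConf(conf: Map[String, String]): Seq[(String, String)] =
    conf.toSeq.collect {
      case (k, v) if k.startsWith(Prefix) => (k.stripPrefix(Prefix), v)
    }
}

// AppMasterEnv.fromConf(Map(
//   "spark.yarn.appMasterEnv.JAVA_HOME" -> "/opt/jdk",
//   "spark.app.name"                    -> "demo"))
// == Seq(("JAVA_HOME", "/opt/jdk"))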
YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs, File.pathSeparator) @@ -417,6 +418,13 @@ trait ClientBase extends Logging { amContainer.setCommands(printableCommands) setupSecurityToken(amContainer) + + // send the acl settings into YARN to control who has access via YARN interfaces + val securityManager = new SecurityManager(sparkConf) + val acls = Map[ApplicationAccessType, String] ( + ApplicationAccessType.VIEW_APP -> securityManager.getViewAcls, + ApplicationAccessType.MODIFY_APP -> securityManager.getModifyAcls) + amContainer.setApplicationACLs(acls) amContainer } } @@ -614,4 +622,40 @@ object ClientBase extends Logging { YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, File.pathSeparator) + /** + * Get the list of namenodes the user may access. + */ + private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "").split(",").map(_.trim()).filter(!_.isEmpty) + .map(new Path(_)).toSet + } + + private[yarn] def getTokenRenewer(conf: Configuration): String = { + val delegTokenRenewer = Master.getMasterPrincipal(conf) + logDebug("delegation token renewer is: " + delegTokenRenewer) + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) + } + delegTokenRenewer + } + + /** + * Obtains tokens for the namenodes passed in and adds them to the credentials. + */ + private[yarn] def obtainTokensForNamenodes(paths: Set[Path], conf: Configuration, + creds: Credentials) { + if (UserGroupInformation.isSecurityEnabled()) { + val delegTokenRenewer = getTokenRenewer(conf) + + paths.foreach { + dst => + val dstFs = dst.getFileSystem(conf) + logDebug("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 4ba7133a959ed..71a9e42846b2b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -171,7 +171,11 @@ trait ExecutorRunnableUtil extends Logging { val extraCp = sparkConf.getOption("spark.executor.extraClassPath") ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp) - // Allow users to specify some environment variables + sparkConf.getExecutorEnv.foreach { case (key, value) => + YarnSparkHadoopUtil.addToEnvironment(env, key, value, File.pathSeparator) + } + + // Keep this for backwards compatibility but users should move to the config YarnSparkHadoopUtil.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"), File.pathSeparator) diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 686714dc36488..68cc2890f3a22 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -31,6 +31,8 @@ import org.apache.hadoop.yarn.api.records.ContainerLaunchContext import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ + + import org.scalatest.FunSuite import org.scalatest.Matchers @@ 
-38,7 +40,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } import scala.util.Try -import org.apache.spark.SparkConf +import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils class ClientBaseSuite extends FunSuite with Matchers { @@ -138,6 +140,57 @@ class ClientBaseSuite extends FunSuite with Matchers { } } + test("check access nns empty") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns unset") { + val sparkConf = new SparkConf() + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access nns space") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access two nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) + } + + test("check token renewer") { + val hadoopConf = new Configuration() + hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") + hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") + val renewer = ClientBase.getTokenRenewer(hadoopConf) + renewer should be ("yarn/myrm:8032@SPARKTEST.COM") + } + + test("check token renewer default") { + val hadoopConf = new Configuration() + val caught = + intercept[SparkException] { + ClientBase.getTokenRenewer(hadoopConf) + } + assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/yarn/pom.xml b/yarn/pom.xml index efb473aa1b261..3faaf053634d6 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -29,7 +29,7 @@ pom Spark Project YARN Parent POM - yarn + yarn diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index ceaf9f9d71001..b6c8456d06684 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-stable + yarn-stable org.apache.spark
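
A few usage notes on the YARN-side changes above; the sketches below are illustrative only, use hypothetical object names, and assume spark-core plus the relevant Hadoop client jars on the classpath.

The new spark.yarn.access.namenodes setting lets a job on a secure cluster obtain HDFS delegation tokens for additional namenodes beyond the one backing the staging directory. ClientBase.getNameNodesToAccess parses it as a comma-separated list, trimming whitespace and dropping empty entries, exactly as the new ClientBaseSuite tests exercise. A minimal sketch of that parsing step:

import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf

object NameNodesExample {
  // Mirrors ClientBase.getNameNodesToAccess: split the comma-separated setting,
  // drop blank entries (e.g. from a trailing comma), and build one Path per URI.
  def nameNodesToAccess(sparkConf: SparkConf): Set[Path] =
    sparkConf.get("spark.yarn.access.namenodes", "")
      .split(",")
      .map(_.trim)
      .filter(_.nonEmpty)
      .map(new Path(_))
      .toSet

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032, ")
    println(nameNodesToAccess(conf))  // Set(hdfs://nn1:8032, hdfs://nn2:8032)
  }
}

The client then adds the staging directory's own path to this set (nns + dst in the patch) before requesting tokens.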
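
When Kerberos security is enabled, ClientBase.obtainTokensForNamenodes asks each namenode's filesystem for delegation tokens, using the resource manager's principal as the renewer; if no principal can be resolved, the client fails fast with a SparkException, which is what the new "check token renewer default" test asserts. A hedged sketch of that flow (TokenExample is an illustrative name; it assumes hadoop-mapreduce-client-core is present for org.apache.hadoop.mapred.Master):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.Master
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.spark.SparkException

object TokenExample {
  // Mirrors ClientBase.getTokenRenewer: fail fast when no RM principal is configured.
  def tokenRenewer(conf: Configuration): String = {
    val renewer = Master.getMasterPrincipal(conf)
    if (renewer == null || renewer.isEmpty) {
      throw new SparkException("Can't get Master Kerberos principal for use as renewer")
    }
    renewer
  }

  // Mirrors ClientBase.obtainTokensForNamenodes: one addDelegationTokens call per path,
  // and a no-op when security is disabled.
  def obtainTokens(paths: Set[Path], conf: Configuration, creds: Credentials): Unit = {
    if (UserGroupInformation.isSecurityEnabled()) {
      val renewer = tokenRenewer(conf)
      paths.foreach { dst =>
        dst.getFileSystem(conf).addDelegationTokens(renewer, creds)
      }
    }
  }
}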
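
Two environment-propagation changes supersede the old SPARK_YARN_USER_ENV variable, which is kept only for backwards compatibility: every spark.yarn.appMasterEnv.KEY entry in the SparkConf is exported into the application master's environment, and spark.executorEnv.KEY entries (read through SparkConf.getExecutorEnv) into the executors'. A small sketch of the prefix filtering, with an illustrative object name:

import org.apache.spark.SparkConf

object YarnEnvExample {
  private val AmPrefix = "spark.yarn.appMasterEnv."

  // Mirrors ClientBase.getAppMasterEnv: keep only the prefixed keys and strip the prefix.
  def appMasterEnv(conf: SparkConf): Seq[(String, String)] =
    conf.getAll.toSeq
      .filter { case (k, _) => k.startsWith(AmPrefix) }
      .map { case (k, v) => (k.substring(AmPrefix.length), v) }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.yarn.appMasterEnv.JAVA_HOME", "/usr/lib/jvm/java-7-openjdk")
      .set("spark.executorEnv.LD_LIBRARY_PATH", "/opt/hadoop/lib/native")
      .set("spark.executor.memory", "2g")  // no recognised prefix, so not exported

    appMasterEnv(conf).foreach { case (k, v) => println(s"AM:       $k=$v") }
    conf.getExecutorEnv.foreach { case (k, v) => println(s"executor: $k=$v") }
  }
}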
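
Finally, the last ClientBase hunk forwards Spark's view and modify ACLs into the ContainerLaunchContext via setApplicationACLs, so the YARN web UI and CLI enforce the same access control as the Spark UI (driven by settings such as spark.acls.enable, spark.ui.view.acls and spark.modify.acls). A hedged sketch of the wiring; because SecurityManager is private[spark], a snippet like this only compiles inside the org.apache.spark package tree, which is where the real code (ClientBase) lives:

package org.apache.spark.deploy.yarn

import scala.collection.JavaConverters._

import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerLaunchContext}
import org.apache.spark.{SecurityManager, SparkConf}

object YarnAclSketch {
  // Build the ACL map YARN expects from Spark's own view/modify ACL strings.
  def applyAcls(sparkConf: SparkConf, amContainer: ContainerLaunchContext): Unit = {
    val securityManager = new SecurityManager(sparkConf)
    val acls = Map(
      ApplicationAccessType.VIEW_APP -> securityManager.getViewAcls,
      ApplicationAccessType.MODIFY_APP -> securityManager.getModifyAcls)
    amContainer.setApplicationACLs(acls.asJava)
  }
}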