diff --git a/.gitignore b/.gitignore index 20095dd97343e..9757054a50f9e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,16 +8,19 @@ *.pyc .idea/ .idea_modules/ -sbt/*.jar +build/*.jar .settings .cache +cache .generated-mima* -/build/ work/ out/ .DS_Store third_party/libmesos.so third_party/libmesos.dylib +build/apache-maven* +build/zinc* +build/scala* conf/java-opts conf/*.sh conf/*.cmd diff --git a/README.md b/README.md index 8d57d50da96c9..16628bd406775 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at -["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html). +["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-spark.html). ## Interactive Scala Shell diff --git a/build/mvn b/build/mvn new file mode 100755 index 0000000000000..43471f83e904c --- /dev/null +++ b/build/mvn @@ -0,0 +1,149 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Determine the current working directory +_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# Preserve the calling directory +_CALLING_DIR="$(pwd)" + +# Installs any application tarball given a URL, the expected tarball name, +# and, optionally, a checkable binary path to determine if the binary has +# already been installed +## Arg1 - URL +## Arg2 - Tarball Name +## Arg3 - Checkable Binary +install_app() { + local remote_tarball="$1/$2" + local local_tarball="${_DIR}/$2" + local binary="${_DIR}/$3" + + # setup `curl` and `wget` silent options if we're running on Jenkins + local curl_opts="" + local wget_opts="" + if [ -n "$AMPLAB_JENKINS" ]; then + curl_opts="-s" + wget_opts="--quiet" + else + curl_opts="--progress-bar" + wget_opts="--progress=bar:force" + fi + + if [ -z "$3" -o ! -f "$binary" ]; then + # check if we already have the tarball + # check if we have curl installed + # download application + [ ! -f "${local_tarball}" ] && [ -n "`which curl 2>/dev/null`" ] && \ + echo "exec: curl ${curl_opts} ${remote_tarball}" && \ + curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" + # if the file still doesn't exist, lets try `wget` and cross our fingers + [ ! -f "${local_tarball}" ] && [ -n "`which wget 2>/dev/null`" ] && \ + echo "exec: wget ${wget_opts} ${remote_tarball}" && \ + wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" + # if both were unsuccessful, exit + [ ! -f "${local_tarball}" ] && \ + echo -n "ERROR: Cannot download $2 with cURL or wget; " && \ + echo "please install manually and try again." 
&& \ + exit 2 + cd "${_DIR}" && tar -xzf "$2" + rm -rf "$local_tarball" + fi +} + +# Install maven under the build/ folder +install_mvn() { + install_app \ + "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \ + "apache-maven-3.2.3-bin.tar.gz" \ + "apache-maven-3.2.3/bin/mvn" + MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn" +} + +# Install zinc under the build/ folder +install_zinc() { + local zinc_path="zinc-0.3.5.3/bin/zinc" + [ ! -f "${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + install_app \ + "http://downloads.typesafe.com/zinc/0.3.5.3" \ + "zinc-0.3.5.3.tgz" \ + "${zinc_path}" + ZINC_BIN="${_DIR}/${zinc_path}" +} + +# Determine the Scala version from the root pom.xml file, set the Scala URL, +# and, with that, download the specific version of Scala necessary under +# the build/ folder +install_scala() { + # determine the Scala version used in Spark + local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | \ + head -1 | cut -f2 -d'>' | cut -f1 -d'<'` + local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" + + install_app \ + "http://downloads.typesafe.com/scala/${scala_version}" \ + "scala-${scala_version}.tgz" \ + "scala-${scala_version}/bin/scala" + + SCALA_COMPILER="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-compiler.jar" + SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" +} + +# Determines if a given application is already installed. If not, will attempt +# to install +## Arg1 - application name +## Arg2 - Alternate path to local install under build/ dir +check_and_install_app() { + # create the local environment variable in uppercase + local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN" + # some black magic to set the generated app variable (i.e. MVN_BIN) into the + # environment + eval "${app_bin}=`which $1 2>/dev/null`" + + if [ -z "`which $1 2>/dev/null`" ]; then + install_$1 + fi +} + +# Setup healthy defaults for the Zinc port if none were provided from +# the environment +ZINC_PORT=${ZINC_PORT:-"3030"} + +# Check and install all applications necessary to build Spark +check_and_install_app "mvn" + +# Install the proper version of Scala and Zinc for the build +install_zinc +install_scala + +# Reset the current working directory +cd "${_CALLING_DIR}" + +# Now that zinc is ensured to be installed, check its status and, if its +# not running or just installed, start it +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then + ${ZINC_BIN} -shutdown + ${ZINC_BIN} -start -port ${ZINC_PORT} \ + -scala-compiler "${SCALA_COMPILER}" \ + -scala-library "${SCALA_LIBRARY}" &>/dev/null +fi + +# Set any `mvn` options if not already present +export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"} + +# Last, call the `mvn` command as usual +${MVN_BIN} "$@" diff --git a/build/sbt b/build/sbt new file mode 100755 index 0000000000000..28ebb64f7197c --- /dev/null +++ b/build/sbt @@ -0,0 +1,128 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so +# that we can run Hive to generate the golden answer. This is not required for normal development +# or testing. +for i in "$HIVE_HOME"/lib/* +do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" +done +export HADOOP_CLASSPATH + +realpath () { +( + TARGET_FILE="$1" + + cd "$(dirname "$TARGET_FILE")" + TARGET_FILE="$(basename "$TARGET_FILE")" + + COUNT=0 + while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] + do + TARGET_FILE="$(readlink "$TARGET_FILE")" + cd $(dirname "$TARGET_FILE") + TARGET_FILE="$(basename $TARGET_FILE)" + COUNT=$(($COUNT + 1)) + done + + echo "$(pwd -P)/"$TARGET_FILE"" +) +} + +. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash + + +declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" +declare -r sbt_opts_file=".sbtopts" +declare -r etc_sbt_opts_file="/etc/sbt/sbtopts" + +usage() { + cat < path to global settings/plugins directory (default: ~/.sbt) + -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series) + -ivy path to local Ivy repository (default: ~/.ivy2) + -mem set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem)) + -no-share use all local caches; no sharing + -no-global uses global caches, but does not use global ~/.sbt directory. + -jvm-debug Turn on JVM debugging, open at the given port. + -batch Disable interactive mode + + # sbt version (default: from project/build.properties if present, else latest release) + -sbt-version use the specified version of sbt + -sbt-jar use the specified jar as the sbt launcher + -sbt-rc use an RC version of sbt + -sbt-snapshot use a snapshot version of sbt + + # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) + -java-home alternate JAVA_HOME + + # jvm options and output control + JAVA_OPTS environment variable, if unset uses "$java_opts" + SBT_OPTS environment variable, if unset uses "$default_sbt_opts" + .sbtopts if this file exists in the current directory, it is + prepended to the runner args + /etc/sbt/sbtopts if this file exists, it is prepended to the runner args + -Dkey=val pass -Dkey=val directly to the java runtime + -J-X pass option -X directly to the java runtime + (-J is stripped) + -S-X add -X to sbt's scalacOptions (-S is stripped) + -PmavenProfiles Enable a maven profile for the build. + +In the case of duplicated or conflicting options, the order above +shows precedence: JAVA_OPTS lowest, command line options highest. 
+EOM +} + +process_my_args () { + while [[ $# -gt 0 ]]; do + case "$1" in + -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; + -no-share) addJava "$noshare_opts" && shift ;; + -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;; + -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; + -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;; + -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; + -batch) exec - + - + diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a074ab8ece1b7..6e4edc7c80d7a 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -76,6 +76,8 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { private val timeout = AkkaUtils.askTimeout(conf) + private val retryAttempts = AkkaUtils.numRetries(conf) + private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) /** Set to the MapOutputTrackerActor living on the driver. */ var trackerActor: ActorRef = _ @@ -108,8 +110,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker(message: Any): Any = { try { - val future = trackerActor.ask(message)(timeout) - Await.result(future, timeout) + AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 49dae5231a92c..ec82d09cd079b 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -151,8 +151,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with private val authOn = sparkConf.getBoolean("spark.authenticate", false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 - private var aclsOn = sparkConf.getOption("spark.acls.enable").getOrElse( - sparkConf.get("spark.ui.acls.enable", "false")).toBoolean + private var aclsOn = + sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false)) // admin acls should be set before view or modify acls private var adminAcls: Set[String] = diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6656df44d3599..43436a1697000 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -395,7 +395,7 @@ object SparkEnv extends Logging { val sparkProperties = (conf.getAll ++ schedulerMode).sorted // System properties that are not java classpaths - val systemProperties = System.getProperties.iterator.toSeq + val systemProperties = Utils.getSystemProperties.toSeq val otherProperties = systemProperties.filter { case (k, _) => k != "java.class.path" && !k.startsWith("spark.") }.sorted diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 86e94931300f8..71b26737b8c02 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ 
b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -80,7 +80,7 @@ private[spark] object JavaUtils { prev match { case Some(k) => underlying match { - case mm: mutable.Map[a, _] => + case mm: mutable.Map[A, _] => mm remove k prev = None case _ => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 81fa0770bbaf9..e8a5cfc746fed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -123,6 +123,7 @@ private[spark] class Master( override def preStart() { logInfo("Starting Spark master at " + masterUrl) + logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 86a87ec22235e..f0f3da5eec4df 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -155,6 +155,7 @@ private[spark] class Worker( assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) + logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index a157e36e2286e..0001c2329c83a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -131,7 +131,7 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() - private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean + private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index f47c2d1fcdcc7..5118e2b911120 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1146,15 +1146,20 @@ abstract class RDD[T: ClassTag]( if (num == 0) { Array.empty } else { - mapPartitions { items => + val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. 
val queue = new BoundedPriorityQueue[T](num)(ord.reverse) queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) - }.reduce { (queue1, queue2) => - queue1 ++= queue2 - queue1 - }.toArray.sorted(ord) + } + if (mapRDDs.partitions.size == 0) { + Array.empty + } else { + mapRDDs.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray.sorted(ord) + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 819b51e12ad8c..4896ec845bbc9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import scala.language.existentials import scala.util.control.NonFatal import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 621a951c27d07..d2947dcea4f7c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ +import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} import org.apache.spark.scheduler.MapStatus @@ -90,6 +91,7 @@ class KryoSerializer(conf: SparkConf) // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) + kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) try { // Use the default classloader when calling the user registrator. diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 7486cb6b1bbc0..b5022fe853c49 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -234,8 +234,9 @@ private[spark] object UIUtils extends Logging {

[garbled XML literal from the UIUtils.scala hunk: the page-header markup is changed so that {org.apache.spark.SPARK_VERSION} is rendered in the navigation bar; the tags surrounding it and the {title} element did not survive extraction]
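The `RDD.takeOrdered` hunk earlier in this patch guards against reducing over an RDD that has zero partitions: previously the per-partition bounded priority queues were merged with `reduce`, which throws `UnsupportedOperationException` ("empty collection") when there is nothing to reduce, so `takeOrdered` failed on an empty RDD. A minimal sketch of the resulting behaviour, assuming only a local `SparkContext` and `sc.emptyRDD` (which produces an RDD with no partitions); the object name is illustrative, not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object TakeOrderedEmptyRddSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("takeOrdered-sketch"))
    try {
      // Normal case: the per-partition bounded priority queues are merged via reduce,
      // yielding the `num` smallest elements in ascending order.
      val nonEmpty = sc.parallelize(Seq(5, 3, 9, 1), numSlices = 2)
      assert(nonEmpty.takeOrdered(2).toSeq == Seq(1, 3))

      // Zero-partition case: reduce() used to be invoked unconditionally and threw
      // UnsupportedOperationException ("empty collection"); with the guard on
      // mapRDDs.partitions.size the call now simply returns an empty array.
      assert(sc.emptyRDD[Int].takeOrdered(2).isEmpty)
    } finally {
      sc.stop()
    }
  }
}
```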

diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e5bdad6bda2fa..5ce299d05824b 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -184,6 +184,7 @@ public void sortByKey() { Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); } + @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { List> pairs = new ArrayList>(); @@ -491,6 +492,7 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @SuppressWarnings("unchecked") @Test public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( @@ -1556,7 +1558,7 @@ static class Class2 {} @Test public void testRegisterKryoClasses() { SparkConf conf = new SparkConf(); - conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class }); + conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class }); Assert.assertEquals( Class1.class.getName() + "," + Class2.class.getName(), conf.get("spark.kryo.classesToRegister")); diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 998f3008ec0ea..97ea3578aa8ba 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers @@ -29,16 +28,10 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter - with LocalSparkContext { +class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" - after { - System.clearProperty("spark.reducer.maxMbInFlight") - System.clearProperty("spark.storage.memoryFraction") - } - test("task throws not serializable exception") { // Ensures that executors do not crash when an exn is not serializable. If executors crash, // this test will hang. 
Correct behavior is that executors don't crash but fail tasks @@ -84,15 +77,14 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter } test("groupByKey where map output sizes exceed maxMbInFlight") { - System.setProperty("spark.reducer.maxMbInFlight", "1") - sc = new SparkContext(clusterUrl, "test") + val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000))) val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect() assert(groups.length === 16) assert(groups.map(_._2).sum === 2000) - // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block } test("accumulators") { @@ -210,7 +202,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter } test("compute without caching when no partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.0001") sc = new SparkContext(clusterUrl, "test") // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory @@ -218,12 +209,11 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter assert(data.count() === 4000000) assert(data.count() === 4000000) assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") } test("compute when only some partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.01") - sc = new SparkContext(clusterUrl, "test") + val conf = new SparkConf().set("spark.storage.memoryFraction", "0.01") + sc = new SparkContext(clusterUrl, "test", conf) // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions // to make sure that *some* of them do fit though @@ -231,7 +221,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter assert(data.count() === 4000000) assert(data.count() === 4000000) assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") } test("passing environment variables to cluster") { diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 49426545c767e..0f49ce4754fbb 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -31,10 +31,11 @@ class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ + def newConf: SparkConf = new SparkConf(loadDefaults = false).set("spark.authenticate", "false") + override def beforeEach() { super.beforeEach() resetSparkContext() - System.setProperty("spark.authenticate", "false") } override def beforeAll() { @@ -52,7 +53,6 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarFile = new File(testTempDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) - System.setProperty("spark.authenticate", "false") val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) @@ -74,7 +74,7 @@ class FileServerSuite extends FunSuite with 
LocalSparkContext { } test("Distributing files locally") { - sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { @@ -108,7 +108,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test("Distributing files locally using URL as input") { // addFile("file:///....") - sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test", newConf) sc.addFile(new File(tmpFile.toString).toURI.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { @@ -122,7 +122,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test ("Dynamically adding JARS locally") { - sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => @@ -133,7 +133,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") + sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { @@ -147,7 +147,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") + sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1,1)) sc.parallelize(testData).foreach { x => @@ -158,7 +158,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,512]", "test") + sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) val testData = Array((1,1)) sc.parallelize(testData).foreach { x => diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 41ed2bce55ce1..7584ae79fc920 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -40,12 +40,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter override def afterEach() { super.afterEach() resetSparkContext() - System.clearProperty("spark.scheduler.mode") } test("local mode, FIFO scheduler") { - System.setProperty("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local[2]", "test") + val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local[2]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. 
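This suite, like most of the test suites touched below, follows one recurring refactoring: settings that used to be injected through `System.setProperty` (and then cleared again in `after {}` blocks) are now carried by a per-test `SparkConf` handed straight to the `SparkContext` constructor, so nothing leaks between suites. A small self-contained sketch of the pattern; the object name and assertion are illustrative, not taken from the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ScopedConfSketch {
  def main(args: Array[String]): Unit = {
    // Before: global JVM state visible to every suite, and easy to forget to clear.
    //   System.setProperty("spark.scheduler.mode", "FIFO")
    //   val sc = new SparkContext("local[2]", "test")

    // After: the setting lives only in this context's SparkConf; nothing to clean up.
    val conf = new SparkConf().set("spark.scheduler.mode", "FIFO")
    val sc = new SparkContext("local[2]", "test", conf)
    try {
      assert(sc.getConf.get("spark.scheduler.mode") == "FIFO")
    } finally {
      sc.stop()
    }
  }
}
```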
@@ -53,10 +52,10 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter } test("local mode, fair scheduler") { - System.setProperty("spark.scheduler.mode", "FAIR") + val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local[2]", "test") + conf.set("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local[2]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -64,8 +63,8 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter } test("cluster mode, FIFO scheduler") { - System.setProperty("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local-cluster[2,1,512]", "test") + val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -73,10 +72,10 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter } test("cluster mode, fair scheduler") { - System.setProperty("spark.scheduler.mode", "FAIR") + val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local-cluster[2,1,512]", "test") + conf.set("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 58a96245a9b53..f57921b768310 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -35,19 +35,15 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex conf.set("spark.test.noStageRetry", "true") test("groupByKey without compression") { - try { - System.setProperty("spark.shuffle.compress", "false") - sc = new SparkContext("local", "test", conf) - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) - val groups = pairs.groupByKey(4).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } finally { - System.setProperty("spark.shuffle.compress", "true") - } + val myConf = conf.clone().set("spark.shuffle.compress", "false") + sc = new SparkContext("local", "test", myConf) + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) + val groups = pairs.groupByKey(4).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) } test("shuffle non-zero block size") { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 5d018ea9868a7..790976a5ac308 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,27 +19,20 @@ package org.apache.spark import 
org.scalatest.FunSuite import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} +import org.apache.spark.util.ResetSystemProperties import com.esotericsoftware.kryo.Kryo -class SparkConfSuite extends FunSuite with LocalSparkContext { +class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { test("loading from system properties") { - try { - System.setProperty("spark.test.testProperty", "2") - val conf = new SparkConf() - assert(conf.get("spark.test.testProperty") === "2") - } finally { - System.clearProperty("spark.test.testProperty") - } + System.setProperty("spark.test.testProperty", "2") + val conf = new SparkConf() + assert(conf.get("spark.test.testProperty") === "2") } test("initializing without loading defaults") { - try { - System.setProperty("spark.test.testProperty", "2") - val conf = new SparkConf(false) - assert(!conf.contains("spark.test.testProperty")) - } finally { - System.clearProperty("spark.test.testProperty") - } + System.setProperty("spark.test.testProperty", "2") + val conf = new SparkConf(false) + assert(!conf.contains("spark.test.testProperty")) } test("named set methods") { @@ -117,23 +110,17 @@ class SparkConfSuite extends FunSuite with LocalSparkContext { test("nested property names") { // This wasn't supported by some external conf parsing libraries - try { - System.setProperty("spark.test.a", "a") - System.setProperty("spark.test.a.b", "a.b") - System.setProperty("spark.test.a.b.c", "a.b.c") - val conf = new SparkConf() - assert(conf.get("spark.test.a") === "a") - assert(conf.get("spark.test.a.b") === "a.b") - assert(conf.get("spark.test.a.b.c") === "a.b.c") - conf.set("spark.test.a.b", "A.B") - assert(conf.get("spark.test.a") === "a") - assert(conf.get("spark.test.a.b") === "A.B") - assert(conf.get("spark.test.a.b.c") === "a.b.c") - } finally { - System.clearProperty("spark.test.a") - System.clearProperty("spark.test.a.b") - System.clearProperty("spark.test.a.b.c") - } + System.setProperty("spark.test.a", "a") + System.setProperty("spark.test.a.b", "a.b") + System.setProperty("spark.test.a.b.c", "a.b.c") + val conf = new SparkConf() + assert(conf.get("spark.test.a") === "a") + assert(conf.get("spark.test.a.b") === "a.b") + assert(conf.get("spark.test.a.b.c") === "a.b.c") + conf.set("spark.test.a.b", "A.B") + assert(conf.get("spark.test.a") === "a") + assert(conf.get("spark.test.a.b") === "A.B") + assert(conf.get("spark.test.a.b.c") === "a.b.c") } test("register kryo classes through registerKryoClasses") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 0390a2e4f1dbb..8ae4f243ec1ae 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -27,10 +27,13 @@ import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite extends FunSuite with LocalSparkContext with PrivateMethodTester with Logging { - def createTaskScheduler(master: String): TaskSchedulerImpl = { + def createTaskScheduler(master: String): TaskSchedulerImpl = + createTaskScheduler(master, new SparkConf()) + + def createTaskScheduler(master: String, conf: SparkConf): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. 
- sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val createTaskSchedulerMethod = PrivateMethod[Tuple2[SchedulerBackend, TaskScheduler]]('createTaskScheduler) val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) @@ -102,19 +105,13 @@ class SparkContextSchedulerCreationSuite } test("local-default-parallelism") { - val defaultParallelism = System.getProperty("spark.default.parallelism") - System.setProperty("spark.default.parallelism", "16") - val sched = createTaskScheduler("local") + val conf = new SparkConf().set("spark.default.parallelism", "16") + val sched = createTaskScheduler("local", conf) sched.backend match { case s: LocalBackend => assert(s.defaultParallelism() === 16) case _ => fail() } - - Option(defaultParallelism) match { - case Some(v) => System.setProperty("spark.default.parallelism", v) - case _ => System.clearProperty("spark.default.parallelism") - } } test("simr") { @@ -155,9 +152,10 @@ class SparkContextSchedulerCreationSuite testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") } - def testMesos(master: String, expectedClass: Class[_]) { + def testMesos(master: String, expectedClass: Class[_], coarse: Boolean) { + val conf = new SparkConf().set("spark.mesos.coarse", coarse.toString) try { - val sched = createTaskScheduler(master) + val sched = createTaskScheduler(master, conf) assert(sched.backend.getClass === expectedClass) } catch { case e: UnsatisfiedLinkError => @@ -168,17 +166,14 @@ class SparkContextSchedulerCreationSuite } test("mesos fine-grained") { - System.setProperty("spark.mesos.coarse", "false") - testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend]) + testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend], coarse = false) } test("mesos coarse-grained") { - System.setProperty("spark.mesos.coarse", "true") - testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend]) + testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend], coarse = true) } test("mesos with zookeeper") { - System.setProperty("spark.mesos.coarse", "false") - testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend]) + testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 1362022104195..8b3c6871a7b39 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -23,55 +23,37 @@ import org.apache.hadoop.io.BytesWritable class SparkContextSuite extends FunSuite with LocalSparkContext { - /** Allows system properties to be changed in tests */ - private def withSystemProperty[T](property: String, value: String)(block: => T): T = { - val originalValue = System.getProperty(property) - try { - System.setProperty(property, value) - block - } finally { - if (originalValue == null) { - System.clearProperty(property) - } else { - System.setProperty(property, originalValue) - } - } - } - test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 - withSystemProperty("spark.driver.allowMultipleContexts", "false") { - val conf = new SparkConf().setAppName("test").setMaster("local") - sc = new SparkContext(conf) - // A SparkContext is already running, so we shouldn't be able to create a second one - intercept[SparkException] { new 
SparkContext(conf) } - // After stopping the running context, we should be able to create a new one - resetSparkContext() - sc = new SparkContext(conf) - } + val conf = new SparkConf().setAppName("test").setMaster("local") + .set("spark.driver.allowMultipleContexts", "false") + sc = new SparkContext(conf) + // A SparkContext is already running, so we shouldn't be able to create a second one + intercept[SparkException] { new SparkContext(conf) } + // After stopping the running context, we should be able to create a new one + resetSparkContext() + sc = new SparkContext(conf) } test("Can still construct a new SparkContext after failing to construct a previous one") { - withSystemProperty("spark.driver.allowMultipleContexts", "false") { - // This is an invalid configuration (no app name or master URL) - intercept[SparkException] { - new SparkContext(new SparkConf()) - } - // Even though those earlier calls failed, we should still be able to create a new context - sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test")) + val conf = new SparkConf().set("spark.driver.allowMultipleContexts", "false") + // This is an invalid configuration (no app name or master URL) + intercept[SparkException] { + new SparkContext(conf) } + // Even though those earlier calls failed, we should still be able to create a new context + sc = new SparkContext(conf.setMaster("local").setAppName("test")) } test("Check for multiple SparkContexts can be disabled via undocumented debug option") { - withSystemProperty("spark.driver.allowMultipleContexts", "true") { - var secondSparkContext: SparkContext = null - try { - val conf = new SparkConf().setAppName("test").setMaster("local") - sc = new SparkContext(conf) - secondSparkContext = new SparkContext(conf) - } finally { - Option(secondSparkContext).foreach(_.stop()) - } + var secondSparkContext: SparkContext = null + try { + val conf = new SparkConf().setAppName("test").setMaster("local") + .set("spark.driver.allowMultipleContexts", "true") + sc = new SparkContext(conf) + secondSparkContext = new SparkContext(conf) + } finally { + Option(secondSparkContext).foreach(_.stop()) } } diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala new file mode 100644 index 0000000000000..8959a843dbd7d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.api.python
+
+import scala.io.Source
+
+import java.io.{PrintWriter, File}
+
+import org.scalatest.{Matchers, FunSuite}
+
+import org.apache.spark.{SharedSparkContext, SparkConf}
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.util.Utils
+
+// This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize
+// a PythonBroadcast:
+class PythonBroadcastSuite extends FunSuite with Matchers with SharedSparkContext {
+  test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") {
+    val tempDir = Utils.createTempDir()
+    val broadcastedString = "Hello, world!"
+    def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = {
+      val source = Source.fromFile(broadcast.path)
+      val contents = source.mkString
+      source.close()
+      contents should be (broadcastedString)
+    }
+    try {
+      val broadcastDataFile: File = {
+        val file = new File(tempDir, "broadcastData")
+        val printWriter = new PrintWriter(file)
+        printWriter.write(broadcastedString)
+        printWriter.close()
+        file
+      }
+      val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath)
+      assertBroadcastIsValid(broadcast)
+      val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+      val deserializedBroadcast =
+        Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance())
+      assertBroadcastIsValid(deserializedBroadcast)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index eb7bd7ab3986e..5eda2d41f0e6d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -23,11 +23,13 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark._
 import org.apache.spark.deploy.SparkSubmit._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ResetSystemProperties, Utils}
 import org.scalatest.FunSuite
 import org.scalatest.Matchers
-class SparkSubmitSuite extends FunSuite with Matchers {
+// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch
+// of properties that need to be cleared after tests.
+class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties { def beforeAll() { System.setProperty("spark.testing", "true") } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index ca226fd4e694f..f8bcde12a371a 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -24,14 +24,14 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListenerTaskEnd, SparkListener} import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} import scala.collection.mutable.ArrayBuffer -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext with ShouldMatchers { +class InputOutputMetricsSuite extends FunSuite with SharedSparkContext with Matchers { test("input metrics when reading text file with single split") { val file = new File(getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(file)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 436eea4f1fdcf..d6ec9e129cceb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -739,7 +739,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F test("accumulator not calculated for resubmitted result stage") { //just for register - val accum = new Accumulator[Int](0, SparkContext.IntAccumulatorParam) + val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) val finalRdd = new MyRDD(sc, 1, Nil) submit(finalRdd, Array(0)) completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index b276343cb412c..24f41bf8cccda 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -26,9 +26,10 @@ import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.ResetSystemProperties -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers - with BeforeAndAfter with BeforeAndAfterAll { +class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter + with BeforeAndAfterAll with ResetSystemProperties { /** Length of time to wait while draining listener events. 
*/ val WAIT_TIMEOUT_MILLIS = 10000 @@ -37,10 +38,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sc = new SparkContext("local", "SparkListenerSuite") } - override def afterAll() { - System.clearProperty("spark.akka.frameSize") - } - test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 5768a3a733f00..3aab5a156ee77 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} import org.apache.spark.storage.TaskResultBlockId /** @@ -55,27 +55,20 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule /** * Tests related to handling task results (both direct and indirect). */ -class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll - with LocalSparkContext { +class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - override def beforeAll { - // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small - // as we can make it) so the tests don't take too long. - System.setProperty("spark.akka.frameSize", "1") - } - - override def afterAll { - System.clearProperty("spark.akka.frameSize") - } + // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small + // as we can make it) so the tests don't take too long. + def conf: SparkConf = new SparkConf().set("spark.akka.frameSize", "1") test("handling results smaller than Akka frame size") { - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) assert(result === 2) } test("handling results larger than Akka frame size") { - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) @@ -89,7 +82,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA test("task retried if result missing from block manager") { // Set the maximum number of task failures to > 0, so that the task set isn't aborted // after the result is missing. - sc = new SparkContext("local[1,2]", "test") + sc = new SparkContext("local[1,2]", "test", conf) // If this test hangs, it's probably because no resource offers were made after the task // failed. 
val scheduler: TaskSchedulerImpl = sc.taskScheduler match { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 7532da88c6065..40aaf9dd1f1e9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -162,12 +162,12 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin } test("Fair Scheduler Test") { - sc = new SparkContext("local", "TaskSchedulerImplSuite") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + val conf = new SparkConf().set("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local", "TaskSchedulerImplSuite", conf) val taskScheduler = new TaskSchedulerImpl(sc) val taskSet = FakeTask.createTaskSet(1) - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) schedulableBuilder.buildPools() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 5554efbcbadf8..ffe6f039145ea 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,7 +33,7 @@ import akka.util.Timeout import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -44,18 +44,17 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat -import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} +import org.apache.spark.util._ -class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter - with PrivateMethodTester { +class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach + with PrivateMethodTester with ResetSystemProperties { private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null var actorSystem: ActorSystem = null var master: BlockManagerMaster = null - var oldArch: String = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -79,13 +78,13 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter manager } - before { + override def beforeEach(): Unit = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( "test", "localhost", 0, conf = conf, securityManager = securityMgr) this.actorSystem = actorSystem // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") + System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") conf.set("spark.driver.port", boundPort.toString) @@ -100,7 +99,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter 
SizeEstimator invokePrivate initialize() } - after { + override def afterEach(): Unit = { if (store != null) { store.stop() store = null @@ -113,14 +112,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter actorSystem.awaitTermination() actorSystem = null master = null - - if (oldArch != null) { - conf.set("os.arch", oldArch) - } else { - System.clearProperty("os.arch") - } - - System.clearProperty("spark.test.useCompressedOops") } test("StorageLevel object caching") { diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 7bca1711ae226..6bbf72e929dcb 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.BlockManagerId /** * Test the AkkaUtils with various security settings. */ -class AkkaUtilsSuite extends FunSuite with LocalSparkContext { +class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { test("remote fetch security bad password") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala new file mode 100644 index 0000000000000..d4b92f33dd9e6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.Properties + +import org.scalatest.{BeforeAndAfterEach, Suite} + +/** + * Mixin for automatically resetting system properties that are modified in ScalaTest tests. + * This resets the properties after each individual test. + * + * The order in which fixtures are mixed in affects the order in which they are invoked by tests. + * If we have a suite `MySuite extends FunSuite with Foo with Bar`, then + * Bar's `super` is Foo, so Bar's beforeEach() will and afterEach() methods will be invoked first + * by the rest runner. + * + * This means that ResetSystemProperties should appear as the last trait in test suites that it's + * mixed into in order to ensure that the system properties snapshot occurs as early as possible. + * ResetSystemProperties calls super.afterEach() before performing its own cleanup, ensuring that + * the old properties are restored as late as possible. + * + * See the "Composing fixtures by stacking traits" section at + * http://www.scalatest.org/user_guide/sharing_fixtures for more details about this pattern. 
+ */ +private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Suite => + var oldProperties: Properties = null + + override def beforeEach(): Unit = { + oldProperties = new Properties(System.getProperties) + super.beforeEach() + } + + override def afterEach(): Unit = { + try { + super.afterEach() + } finally { + System.setProperties(oldProperties) + oldProperties = null + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 0ea2d13a83505..7424c2e91d4f2 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.util -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite -import org.scalatest.PrivateMethodTester +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} class DummyClass1 {} @@ -46,20 +44,12 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { + extends FunSuite with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - var oldArch: String = _ - var oldOops: String = _ - - override def beforeAll() { + override def beforeEach() { // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - } - - override def afterAll() { - resetOrClear("os.arch", oldArch) - resetOrClear("spark.test.useCompressedOops", oldOops) + System.setProperty("os.arch", "amd64") + System.setProperty("spark.test.useCompressedOops", "true") } test("simple classes") { @@ -122,7 +112,7 @@ class SizeEstimatorSuite } test("32-bit arch") { - val arch = System.setProperty("os.arch", "x86") + System.setProperty("os.arch", "x86") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -131,14 +121,13 @@ class SizeEstimatorSuite assertResult(48)(SizeEstimator.estimate(DummyString("a"))) assertResult(48)(SizeEstimator.estimate(DummyString("ab"))) assertResult(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) - resetOrClear("os.arch", arch) } // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors // (Sun vs IBM). Use a DummyString class to make tests deterministic. 
test("64-bit arch with no compressed oops") { - val arch = System.setProperty("os.arch", "amd64") - val oops = System.setProperty("spark.test.useCompressedOops", "false") + System.setProperty("os.arch", "amd64") + System.setProperty("spark.test.useCompressedOops", "false") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -146,16 +135,5 @@ class SizeEstimatorSuite assertResult(64)(SizeEstimator.estimate(DummyString("a"))) assertResult(64)(SizeEstimator.estimate(DummyString("ab"))) assertResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) - - resetOrClear("os.arch", arch) - resetOrClear("spark.test.useCompressedOops", oops) - } - - def resetOrClear(prop: String, oldValue: String) { - if (oldValue != null) { - System.setProperty(prop, oldValue) - } else { - System.clearProperty(prop) - } } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index f9d4bea823f7c..4544382094f96 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkConf -class UtilsSuite extends FunSuite { +class UtilsSuite extends FunSuite with ResetSystemProperties { test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 3b89aaba29609..b1b8cb44e098b 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -87,8 +87,8 @@ if [[ ! "$@" =~ --package-only ]]; then git commit -a -m "Preparing development version $next_ver" git push origin $GIT_TAG git push origin HEAD:$GIT_BRANCH - git checkout -f $GIT_TAG - + git checkout -f $GIT_TAG + # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API echo "Creating Nexus staging repository" @@ -106,7 +106,7 @@ if [[ ! "$@" =~ --package-only ]]; then clean install ./dev/change-version-to-2.11.sh - + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install @@ -174,7 +174,7 @@ make_binary_release() { NAME=$1 FLAGS=$2 cp -r spark spark-$RELEASE_VERSION-bin-$NAME - + cd spark-$RELEASE_VERSION-bin-$NAME # TODO There should probably be a flag to make-distribution to allow 2.11 support @@ -219,7 +219,7 @@ scp spark-* \ # Docs cd spark -sbt/sbt clean +build/sbt clean cd docs # Compile docs with Java 7 to use nicer format JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build diff --git a/dev/mima b/dev/mima index 40603166c21ae..bed5cd042634e 100755 --- a/dev/mima +++ b/dev/mima @@ -24,13 +24,13 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -echo -e "q\n" | sbt/sbt oldDeps/update +echo -e "q\n" | build/sbt oldDeps/update rm -f .generated-mima* -# Generate Mima Ignore is called twice, first with latest built jars +# Generate Mima Ignore is called twice, first with latest built jars # on the classpath and then again with previous version jars on the classpath. # Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath -# it did not process the new classes (which are in assembly jar). +# it did not process the new classes (which are in assembly jar). 
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" @@ -38,7 +38,7 @@ echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore -echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" +echo -e "q\n" | build/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? if [ $ret_val != 0 ]; then diff --git a/dev/run-tests b/dev/run-tests index 9192cb7e169f3..20603fc089239 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -59,17 +59,17 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" { if test -x "$JAVA_HOME/bin/java"; then declare java_cmd="$JAVA_HOME/bin/java" - else + else declare java_cmd=java fi - + # We can't use sed -r -e due to OS X / BSD compatibility; hence, all the parentheses. JAVA_VERSION=$( $java_cmd -version 2>&1 \ | grep -e "^java version" --max-count=1 \ | sed "s/java version \"\(.*\)\.\(.*\)\.\(.*\)\"/\1\2/" ) - + if [ "$JAVA_VERSION" -lt 18 ]; then echo "[warn] Java 8 tests will not run because JDK version is < 1.8." fi @@ -79,7 +79,7 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" # Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master - + sql_diffs=$( git diff --name-only master \ | grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" @@ -93,7 +93,7 @@ if [ -n "$AMPLAB_JENKINS" ]; then if [ -n "$sql_diffs" ]; then echo "[info] Detected changes in SQL. Will run Hive test suite." _RUN_SQL_TESTS=true - + if [ -z "$non_sql_diffs" ]; then echo "[info] Detected no changes except in SQL. Will only run SQL tests." _SQL_TESTS_ONLY=true @@ -151,7 +151,7 @@ CURRENT_BLOCK=$BLOCK_BUILD HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0" echo "[info] Compile with Hive 0.12.0" echo -e "q\n" \ - | sbt/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \ + | build/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" # Then build with default Hive version (0.13.1) because tests are based on this version @@ -160,7 +160,7 @@ CURRENT_BLOCK=$BLOCK_BUILD echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\ " -Phive -Phive-thriftserver" echo -e "q\n" \ - | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \ + | build/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } @@ -177,7 +177,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_RUN_SQL_TESTS" ]; then SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" fi - + if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string # will be interpreted as a single test, which doesn't work. @@ -185,19 +185,19 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS else SBT_MAVEN_TEST_ARGS=("test") fi - + echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}" - + # NOTE: echo "q" is needed because sbt on encountering a build file with failure # (either resolution or compilation) prompts the user for input either q, r, etc # to quit or retry. This echo is there to make it not block. 
- # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a + # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a # single argument! # "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array. # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? echo -e "q\n" \ - | sbt/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \ + | build/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } diff --git a/dev/scalastyle b/dev/scalastyle index 3a4df6e4bf1bc..86919227ed1ab 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,9 +17,9 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt +echo -e "q\n" | build/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN built too -echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle \ +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle \ >> scalastyle.txt ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') diff --git a/docs/README.md b/docs/README.md index 119484038083f..8a54724c4beae 100644 --- a/docs/README.md +++ b/docs/README.md @@ -21,7 +21,7 @@ read those text files directly if you want. Start with index.md. The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com). `Jekyll` and a few dependencies must be installed for this to work. We recommend -installing via the Ruby Gem dependency manager. Since the exact HTML output +installing via the Ruby Gem dependency manager. Since the exact HTML output varies between versions of Jekyll and its dependencies, we list specific versions here in some cases: @@ -60,7 +60,7 @@ We use Sphinx to generate Python API docs, so you will need to install it by run ## API Docs (Scaladoc and Sphinx) -You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. +You can build just the Spark scaladoc by running `build/sbt doc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as @@ -68,7 +68,7 @@ public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a -jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it +jekyll plugin to run `build/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 4566a2fff562b..3c626a0b7f54b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -25,8 +25,8 @@ curr_dir = pwd cd("..") - puts "Running 'sbt/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `sbt/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `build/sbt -Pkinesis-asl compile unidoc` puts "Moving back into docs dir." 
cd("docs") diff --git a/docs/building-spark.md b/docs/building-spark.md index 70165eabca435..c1bcd91b5b853 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -9,6 +9,15 @@ redirect_from: "building-with-maven.html" Building Spark using Maven requires Maven 3.0.4 or newer and Java 6+. +# Building with `build/mvn` + +Spark now comes packaged with a self-contained Maven installation to ease building and deployment of Spark from source located under the `build/` directory. This script will automatically download and setup all necessary build requirements ([Maven](https://maven.apache.org/), [Scala](http://www.scala-lang.org/), and [Zinc](https://github.com/typesafehub/zinc)) locally within the `build/` directory itself. It honors any `mvn` binary if present already, however, will pull down its own copy of Scala and Zinc regardless to ensure proper version requirements are met. `build/mvn` execution acts as a pass through to the `mvn` call allowing easy transition from previous build methods. As an example, one can build a version of Spark as follows: + +{% highlight bash %} +build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package +{% endhighlight %} + +Other build examples can be found below. # Setting up Maven's Memory Usage @@ -28,7 +37,9 @@ If you don't run this, you may see errors like the following: You can fix this by setting the `MAVEN_OPTS` variable as discussed before. -**Note:** *For Java 8 and above this step is not required.* +**Note:** +* *For Java 8 and above this step is not required.* +* *If using `build/mvn` and `MAVEN_OPTS` were not already set, the script will automate this for you.* # Specifying the Hadoop Version @@ -60,7 +71,7 @@ mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -DskipTests clean package mvn -Phadoop-0.23 -Dhadoop.version=0.23.7 -DskipTests clean package {% endhighlight %} -For Apache Hadoop 2.x, 0.23.x, Cloudera CDH, and other Hadoop versions with YARN, you can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". As of Spark 1.3, Spark only supports YARN versions 2.2.0 and later. +You can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". Spark only supports YARN versions 2.2.0 and later. Examples: @@ -84,7 +95,7 @@ mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` and `Phive-thriftserver` profiles to your existing build options. -By default Spark will build with Hive 0.13.1 bindings. You can also build for +By default Spark will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using the `-Phive-0.12.0` profile. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support @@ -106,7 +117,7 @@ supported in Scala 2.11 builds. # Spark Tests in Maven -Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). +Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. 
The following is an example of a correct (build, test) sequence: @@ -124,7 +135,7 @@ We use the scala-maven-plugin which supports incremental and continuous compilat mvn scala:cc -should run continuous compilation (i.e. wait for changes). However, this has not been tested +should run continuous compilation (i.e. wait for changes). However, this has not been tested extensively. A couple of gotchas to note: * it only scans the paths `src/main` and `src/test` (see [docs](http://scala-tools.org/mvnsites/maven-scala-plugin/usage_cc.html)), so it will only work @@ -157,9 +168,9 @@ The debian package can then be found under assembly/target. We added the short c Running only Java 8 tests and nothing else. mvn install -DskipTests -Pjava8-tests - -Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`. -For these tests to run your system must have a JDK 8 installation. + +Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`. +For these tests to run your system must have a JDK 8 installation. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. # Building for PySpark on YARN @@ -171,7 +182,7 @@ then ship it over to the cluster. We are investigating the exact cause for this. # Packaging without Hadoop Dependencies for YARN -The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. # Building with SBT @@ -182,22 +193,22 @@ compilation. More advanced developers may wish to use SBT. The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables can be set to control the SBT build. For example: - sbt/sbt -Pyarn -Phadoop-2.3 assembly + build/sbt -Pyarn -Phadoop-2.3 assembly # Testing with SBT -Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: +Some of the tests require Spark to be packaged first, so always run `build/sbt assembly` the first time. 
The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test + build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly + build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" + build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" To run the test suites of a specific sub project: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test + build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test # Speeding up Compilation with Zinc @@ -206,3 +217,9 @@ compiler. When run locally as a background process, it speeds up builds of Scala-based projects like Spark. Developers who regularly recompile Spark with Maven will be the most interested in Zinc. The project site gives instructions for building and running `zinc`; OS X users can install it using `brew install zinc`. + +If using the `build/mvn` script, `zinc` will automatically be downloaded and leveraged for all +builds. This process will auto-start after the first time `build/mvn` is called and bind to port +3030 unless the `ZINC_PORT` environment variable is set. The `zinc` process can subsequently be +shut down at any time by running `build/zinc-<version>/bin/zinc -shutdown` and will automatically +restart whenever `build/mvn` is called. diff --git a/docs/configuration.md b/docs/configuration.md index 2cc013c47fdbb..fa9d311f85068 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -452,7 +452,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedJobs 1000 - How many stages the Spark UI and status APIs remember before garbage + How many jobs the Spark UI and status APIs remember before garbage collecting. diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index dd73e9dc54440..87dcc58feb494 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -18,7 +18,7 @@ see the guide on [building with maven](building-spark.html#specifying-the-hadoop The table below lists the corresponding `hadoop.version` code for each CDH/HDP release. Note that some Hadoop releases are binary compatible across client versions. This means the pre-built Spark -distribution may "just work" without you needing to compile. That said, we recommend compiling with +distribution may "just work" without you needing to compile. That said, we recommend compiling with the _exact_ Hadoop version you are running to avoid any compatibility errors. @@ -50,7 +50,7 @@ the _exact_ Hadoop version you are running to avoid any compatibility errors. In SBT, the equivalent can be achieved by setting the `hadoop.version` property: - sbt/sbt -Dhadoop.version=1.0.4 assembly + build/sbt -Dhadoop.version=1.0.4 assembly # Linking Applications to the Hadoop Version @@ -98,11 +98,11 @@ Spark can run in a variety of deployment modes: * Using dedicated set of Spark nodes in your cluster. These nodes should be co-located with your Hadoop installation. -* Running on the same nodes as an existing Hadoop installation, with a fixed amount memory and +* Running on the same nodes as an existing Hadoop installation, with a fixed amount of memory and cores dedicated to Spark on each node.
* Run Spark alongside Hadoop using a cluster resource manager, such as YARN or Mesos. -These options are identical for those using CDH and HDP. +These options are identical for those using CDH and HDP. # Inheriting Cluster Configuration @@ -116,5 +116,5 @@ The location of these configuration files varies across CDH and HDP versions, bu a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create configurations on-the-fly, but offer a mechanisms to download copies of them. -To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` +To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` to a location containing the configuration files. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2aea8a8aedafc..729045b81a8c0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -831,13 +831,10 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. -Note that if you call `schemaRDD.cache()` rather than `sqlContext.cacheTable(...)`, tables will _not_ be cached using -the in-memory columnar format, and therefore `sqlContext.cacheTable(...)` is strongly recommended for this use case. - Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running `SET key=value` commands using SQL. @@ -1010,12 +1007,11 @@ let user control table caching explicitly: CACHE TABLE logs_last_month; UNCACHE TABLE logs_last_month; -**NOTE:** `CACHE TABLE tbl` is lazy, similar to `.cache` on an RDD. This command only marks `tbl` to ensure that -partitions are cached when calculated but doesn't actually cache it until a query that touches `tbl` is executed. -To force the table to be cached, you may simply count the table immediately after executing `CACHE TABLE`: +**NOTE:** `CACHE TABLE tbl` is now __eager__ by default not __lazy__. Don’t need to trigger cache materialization manually anymore. - CACHE TABLE logs_last_month; - SELECT COUNT(1) FROM logs_last_month; +Spark SQL newly introduced a statement to let user control table caching whether or not lazy since Spark 1.2.0: + + CACHE [LAZY] TABLE [AS SELECT] ... Several caching related features are not supported yet: diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 1ac5b9e863ad4..01450efe35e55 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -75,7 +75,7 @@ main entry point for all streaming functionality. We create a local StreamingCon {% highlight scala %} import org.apache.spark._ import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ // Create a local StreamingContext with two working thread and batch interval of 1 second. // The master requires 2 cores to prevent from a starvation scenario. @@ -107,7 +107,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. 
Next, we want to count these words. {% highlight scala %} -import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ // Count each word in each batch val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 556d99d1027b8..485eea4f5e683 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -39,10 +39,10 @@ from optparse import OptionParser from sys import stderr -DEFAULT_SPARK_VERSION = "1.1.0" +DEFAULT_SPARK_VERSION = "1.2.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) -MESOS_SPARK_EC2_BRANCH = "v4" +MESOS_SPARK_EC2_BRANCH = "branch-1.3" # A URL prefix from which to fetch AMI information AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) @@ -251,10 +251,13 @@ def get_spark_shark_version(opts): "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1", + # These are dummy versions (no Shark versions after this) "1.0.0": "1.0.0", "1.0.1": "1.0.1", "1.0.2": "1.0.2", "1.1.0": "1.1.0", + "1.1.1": "1.1.1", + "1.2.0": "1.2.0", } version = opts.spark_version.replace("v", "") if version not in spark_shark_map: diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index c9e1511278ede..2adc63f7ff30e 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -22,7 +22,6 @@ import java.util.Properties import kafka.producer._ import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.kafka._ import org.apache.spark.SparkConf diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index adecd934358c4..1b53f3edbe92e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -28,11 +28,9 @@ object BroadcastTest { val bcName = if (args.length > 2) args(2) else "Http" val blockSize = if (args.length > 3) args(3) else "4096" - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + - "BroadcastFactory") - System.setProperty("spark.broadcast.blockSize", blockSize) val sparkConf = new SparkConf().setAppName("Broadcast Test") - + .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory") + .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) val slices = if (args.length > 0) args(0).toInt else 2 diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala new file mode 100644 index 0000000000000..948c350953e27 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.clustering.GaussianMixtureEM +import org.apache.spark.mllib.linalg.Vectors + +/** + * An example Gaussian Mixture Model EM app. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM <input> <k> <convergenceTol> + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DenseGmmEM { + def main(args: Array[String]): Unit = { + if (args.length < 3) { + println("usage: DenseGmmEM <input file> <k> <convergenceTol> [maxIterations]") + } else { + val maxIterations = if (args.length > 3) args(3).toInt else 100 + run(args(0), args(1).toInt, args(2).toDouble, maxIterations) + } + } + + private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { + val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") + val ctx = new SparkContext(conf) + + val data = ctx.textFile(inputFile).map { line => + Vectors.dense(line.trim.split(' ').map(_.toDouble)) + }.cache() + + val clusters = new GaussianMixtureEM() + .setK(k) + .setConvergenceTol(convergenceTol) + .setMaxIterations(maxIterations) + .run(data) + + for (i <- 0 until clusters.k) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) + } + + println("Cluster labels (first <= 100):") + val clusterLabels = clusters.predict(data) + clusterLabels.take(100).foreach { x => + print(" " + x) + } + println() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 6bb659fbd8be8..30269a7ccae97 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -23,7 +23,6 @@ import java.net.Socket import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.receiver.Receiver /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 6c24bc3ad09e0..4b4667fec44e6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ /** * Counts words in new text files created in the given directory diff --git
a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index e4283e04a1b11..6ff0c47793a25 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -22,7 +22,6 @@ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.mqtt._ import org.apache.spark.SparkConf diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index ae0a08c6cdb1a..2cd8073dada14 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.storage.StorageLevel /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala index 4caa90659111a..13ba9a43ec3c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.SynchronizedQueue import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ object QueueStream { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 19427e629f76d..c3a05c89d817e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -25,7 +25,6 @@ import com.google.common.io.Files import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.util.IntParam /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index ed186ea5650c4..345d0bc441351 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -20,7 +20,6 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf import org.apache.spark.HashPartitioner import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every diff --git 
a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index 683752ac96241..62f49530edb12 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkConf import org.apache.spark.SparkContext._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.twitter._ // scalastyle:off diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f55d23ab3924b..f253d75b279f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} -import StreamingContext._ import org.apache.spark.SparkContext._ import org.apache.spark.streaming.twitter._ import org.apache.spark.SparkConf diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 79905af381a12..6510c70bd1866 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -24,7 +24,6 @@ import akka.zeromq.Subscribe import akka.util.ByteString import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.zeromq._ import scala.language.implicitConversions diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 55226c0a6df60..fbacaee98690f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.examples.streaming.StreamingExamples // scalastyle:off /** Analyses a streaming dataset of web page views. 
This class demonstrates several types of diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071b..1e24da7f5f60c 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071b..1e24da7f5f60c 100644 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071b..1e24da7f5f60c 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git 
a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071b..1e24da7f5f60c 100644 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/extras/java8-tests/README.md b/extras/java8-tests/README.md index e95b73ac7702a..dc9e87f2eeb92 100644 --- a/extras/java8-tests/README.md +++ b/extras/java8-tests/README.md @@ -8,7 +8,7 @@ to your Java location. The set-up depends a bit on the build system: `-java-home` to the sbt launch script. If a Java 8 JDK is detected sbt will automatically include the Java 8 test project. - `$ JAVA_HOME=/opt/jdk1.8.0/ sbt/sbt clean "test-only org.apache.spark.Java8APISuite"` + `$ JAVA_HOME=/opt/jdk1.8.0/ build/sbt clean "test-only org.apache.spark.Java8APISuite"` * For Maven users, @@ -19,6 +19,6 @@ to your Java location. The set-up depends a bit on the build system: `$ JAVA_HOME=/opt/jdk1.8.0/ mvn clean install -DskipTests` `$ JAVA_HOME=/opt/jdk1.8.0/ mvn test -Pjava8-tests -DwildcardSuites=org.apache.spark.Java8APISuite` - Note that the above command can only be run from project root directory since this module - depends on core and the test-jars of core and streaming. This means an install step is + Note that the above command can only be run from project root directory since this module + depends on core and the test-jars of core and streaming. This means an install step is required to make the test dependencies visible to the Java 8 sub-project. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala new file mode 100644 index 0000000000000..bdf984aee4dae --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable.IndexedSeq + +import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} +import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.util.MLUtils + +/** + * This class performs expectation maximization for multivariate Gaussian + * Mixture Models (GMMs). A GMM represents a composite distribution of + * independent Gaussian distributions with associated "mixing" weights + * specifying each's contribution to the composite. + * + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * less than convergenceTol, or until it has reached the max number of iterations. + * While this process is generally guaranteed to converge, it is not guaranteed + * to find a global optimum. + * + * @param k The number of independent Gaussians in the mixture model + * @param convergenceTol The maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations The maximum number of iterations to perform + */ +class GaussianMixtureEM private ( + private var k: Int, + private var convergenceTol: Double, + private var maxIterations: Int) extends Serializable { + + /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ + def this() = this(2, 0.01, 100) + + // number of samples per cluster to use when initializing Gaussians + private val nSamples = 5 + + // an initializing GMM can be provided rather than using the + // default random starting point + private var initialModel: Option[GaussianMixtureModel] = None + + /** Set the initial GMM starting point, bypassing the random initialization. + * You must call setK() prior to calling this method, and the condition + * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + */ + def setInitialModel(model: GaussianMixtureModel): this.type = { + if (model.k == k) { + initialModel = Some(model) + } else { + throw new IllegalArgumentException("mismatched cluster count (model.k != k)") + } + this + } + + /** Return the user supplied initial GMM, if supplied */ + def getInitialModel: Option[GaussianMixtureModel] = initialModel + + /** Set the number of Gaussians in the mixture model. Default: 2 */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** Return the number of Gaussians in the mixture model */ + def getK: Int = k + + /** Set the maximum number of iterations to run. Default: 100 */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** Return the maximum number of iterations to run */ + def getMaxIterations: Int = maxIterations + + /** + * Set the largest change in log-likelihood at which convergence is + * considered to have occurred. + */ + def setConvergenceTol(convergenceTol: Double): this.type = { + this.convergenceTol = convergenceTol + this + } + + /** Return the largest change in log-likelihood at which convergence is + * considered to have occurred. 
+ */ + def getConvergenceTol: Double = convergenceTol + + /** Perform expectation maximization */ + def run(data: RDD[Vector]): GaussianMixtureModel = { + val sc = data.sparkContext + + // we will operate on the data as breeze data + val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() + + // Get length of the input vectors + val d = breezeData.first.length + + // Determine initial weights and corresponding Gaussians. + // If the user supplied an initial GMM, we use those values, otherwise + // we start with uniform weights, a random mean from the data, and + // diagonal covariance matrices using component variances + // derived from the samples + val (weights, gaussians) = initialModel match { + case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => + new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix) + }) + + case None => { + val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + val slice = samples.view(i * nSamples, (i + 1) * nSamples) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + }) + } + } + + var llh = Double.MinValue // current log-likelihood + var llhp = 0.0 // previous log-likelihood + + var iter = 0 + while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) { + // create and broadcast curried cluster contribution function + val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) + + // aggregate the cluster contribution for all sample points + val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) + + // Create new distributions based on the partial assignments + // (often referred to as the "M" step in literature) + val sumWeights = sums.weights.sum + var i = 0 + while (i < k) { + val mu = sums.means(i) / sums.weights(i) + val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr + weights(i) = sums.weights(i) / sumWeights + gaussians(i) = new MultivariateGaussian(mu, sigma) + i = i + 1 + } + + llhp = llh // current becomes previous + llh = sums.logLikelihood // this is the freshly computed log-likelihood + iter += 1 + } + + // Need to convert the breeze matrices to MLlib matrices + val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) } + val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) } + new GaussianMixtureModel(weights, means, sigmas) + } + + /** Average of dense breeze vectors */ + private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = { + val v = BreezeVector.zeros[Double](x(0).length) + x.foreach(xi => v += xi) + v / x.length.toDouble + } + + /** + * Construct matrix where diagonal entries are element-wise + * variance of input vectors (computes biased variance) + */ + private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = { + val mu = vectorMean(x) + val ss = BreezeVector.zeros[Double](x(0).length) + x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u) + diag(ss / x.length.toDouble) + } +} + +// companion class to provide zero constructor for ExpectationSum +private object ExpectationSum { + def zero(k: Int, d: Int): ExpectationSum = { + new ExpectationSum(0.0, Array.fill(k)(0.0), + Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + } + + // compute cluster contributions for each input point + // (U, T) => U for aggregation + def add( + weights: Array[Double], + dists: 
Array[MultivariateGaussian]) + (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = { + val p = weights.zip(dists).map { + case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x) + } + val pSum = p.sum + sums.logLikelihood += math.log(pSum) + val xxt = x * new Transpose(x) + var i = 0 + while (i < sums.k) { + p(i) /= pSum + sums.weights(i) += p(i) + sums.means(i) += x * p(i) + sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr + i = i + 1 + } + sums + } +} + +// Aggregation class for partial expectation results +private class ExpectationSum( + var logLikelihood: Double, + val weights: Array[Double], + val means: Array[BreezeVector[Double]], + val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { + + val k = weights.length + + def +=(x: ExpectationSum): ExpectationSum = { + var i = 0 + while (i < k) { + weights(i) += x.weights(i) + means(i) += x.means(i) + sigmas(i) += x.sigmas(i) + i = i + 1 + } + logLikelihood += x.logLikelihood + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala new file mode 100644 index 0000000000000..11a110db1f7ca --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseVector => BreezeVector} + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.util.MLUtils + +/** + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * + * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is + * the weight for Gaussian i, and weight.sum == 1 + * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i + * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the + * covariance matrix for Gaussian i + */ +class GaussianMixtureModel( + val weight: Array[Double], + val mu: Array[Vector], + val sigma: Array[Matrix]) extends Serializable { + + /** Number of gaussians in mixture */ + def k: Int = weight.length + + /** Maps given points to their cluster indices. 
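+   * Each point is assigned to the Gaussian component with the largest membership value,
+   * as computed by predictMembership().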
*/ + def predict(points: RDD[Vector]): RDD[Int] = { + val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k) + responsibilityMatrix.map(r => r.indexOf(r.max)) + } + + /** + * Given the input vectors, return the membership value of each vector + * to all mixture components. + */ + def predictMembership( + points: RDD[Vector], + mu: Array[Vector], + sigma: Array[Matrix], + weight: Array[Double], + k: Int): RDD[Array[Double]] = { + val sc = points.sparkContext + val dists = sc.broadcast { + (0 until k).map { i => + new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) + }.toArray + } + val weights = sc.broadcast(weight) + points.map { x => + computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) + } + } + + /** + * Compute the partial assignments for each vector + */ + private def computeSoftAssignments( + pt: BreezeVector[Double], + dists: Array[MultivariateGaussian], + weights: Array[Double], + k: Int): Array[Double] = { + val p = weights.zip(dists).map { + case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) + } + val pSum = p.sum + for (i <- 0 until k) { + p(i) /= pSum + } + p + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 6189dce9b27da..7752c1988fdd1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -24,7 +24,6 @@ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7960f3cab576f..d25a7cd5b439d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -71,7 +71,8 @@ class Word2Vec extends Serializable with Logging { private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() - + private var minCount = 5 + /** * Sets vector size (default: 100). */ @@ -114,6 +115,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + * model's vocabulary (default: 5). 
+ */ + def setMinCount(minCount: Int): this.type = { + this.minCount = minCount + this + } + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -122,9 +132,6 @@ class Word2Vec extends Serializable with Logging { /** context words from [-window, window] */ private val window = 5 - /** minimum frequency to consider a vocabulary word */ - private val minCount = 5 - private var trainWordsCount = 0 private var vocabSize = 0 private var vocab: Array[VocabWord] = null diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 327366a1a3a82..5a7281ec6dc3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.linalg -import java.util.{Random, Arrays} +import java.util.{Arrays, Random} -import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} +import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHashSet, ArrayBuffer} + +import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} /** * Trait for a local matrix. @@ -80,6 +82,16 @@ sealed trait Matrix extends Serializable { /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + + /** Map the values of this matrix using a function. Generates a new matrix. Performs the + * function on only the backing array. For example, an operation such as addition or + * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ + private[mllib] def map(f: Double => Double): Matrix + + /** Update all the values of this matrix using the function f. Performed in-place on the + * backing array. For example, an operation such as addition or subtraction will only be + * performed on the non-zero values in a `SparseMatrix`. */ + private[mllib] def update(f: Double => Double): Matrix } /** @@ -123,6 +135,122 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) } override def copy = new DenseMatrix(numRows, numCols, values.clone()) + + private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f)) + + private[mllib] def update(f: Double => Double): DenseMatrix = { + val len = values.length + var i = 0 + while (i < len) { + values(i) = f(values(i)) + i += 1 + } + this + } + + /** Generate a `SparseMatrix` from the given `DenseMatrix`. */ + def toSparse(): SparseMatrix = { + val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble + val colPtrs: Array[Int] = new Array[Int](numCols + 1) + val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt + var nnz = 0 + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + val v = values(indStart + i) + if (v != 0.0) { + rowIndices += i + spVals += v + nnz += 1 + } + i += 1 + } + j += 1 + colPtrs(j) = nnz + } + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) + } +} + +/** + * Factory methods for [[org.apache.spark.mllib.linalg.DenseMatrix]]. + */ +object DenseMatrix { + + /** + * Generate a `DenseMatrix` consisting of zeros. 
+ * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros + */ + def zeros(numRows: Int, numCols: Int): DenseMatrix = + new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + + /** + * Generate a `DenseMatrix` consisting of ones. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones + */ + def ones(numRows: Int, numCols: Int): DenseMatrix = + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + + /** + * Generate an Identity Matrix in `DenseMatrix` format. + * @param n number of rows and columns of the matrix + * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal + */ + def eye(n: Int): DenseMatrix = { + val identity = DenseMatrix.zeros(n, n) + var i = 0 + while (i < n) { + identity.update(i, i, 1.0) + i += 1 + } + identity + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param rng a random number generator + * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param rng a random number generator + * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) + } + + /** + * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. + * @param vector a `Vector` that will form the values on the diagonal of the matrix + * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` + * on the diagonal + */ + def diag(vector: Vector): DenseMatrix = { + val n = vector.size + val matrix = DenseMatrix.zeros(n, n) + val values = vector.toArray + var i = 0 + while (i < n) { + matrix.update(i, i, values(i)) + i += 1 + } + matrix + } } /** @@ -156,6 +284,8 @@ class SparseMatrix( require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + s"numCols: $numCols") + require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") override def toArray: Array[Double] = { val arr = new Array[Double](numRows * numCols) @@ -188,7 +318,7 @@ class SparseMatrix( private[mllib] def update(i: Int, j: Int, v: Double): Unit = { val ind = index(i, j) - if (ind == -1){ + if (ind == -1) { throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. 
Only non-zero elements in Sparse Matrices can be updated.") } else { @@ -197,6 +327,192 @@ class SparseMatrix( } override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + + private[mllib] def map(f: Double => Double) = + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f)) + + private[mllib] def update(f: Double => Double): SparseMatrix = { + val len = values.length + var i = 0 + while (i < len) { + values(i) = f(values(i)) + i += 1 + } + this + } + + /** Generate a `DenseMatrix` from the given `SparseMatrix`. */ + def toDense(): DenseMatrix = { + new DenseMatrix(numRows, numCols, toArray) + } +} + +/** + * Factory methods for [[org.apache.spark.mllib.linalg.SparseMatrix]]. + */ +object SparseMatrix { + + /** + * Generate a `SparseMatrix` from Coordinate List (COO) format. Input must be an array of + * (i, j, value) tuples. Entries that have duplicate values of i and j are + * added together. Tuples where value is equal to zero will be omitted. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param entries Array of (i, j, value) tuples + * @return The corresponding `SparseMatrix` + */ + def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { + val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1)) + val numEntries = sortedEntries.size + if (sortedEntries.nonEmpty) { + // Since the entries are sorted by column index, we only need to check the first and the last. + for (col <- Seq(sortedEntries.head._2, sortedEntries.last._2)) { + require(col >= 0 && col < numCols, s"Column index out of range [0, $numCols): $col.") + } + } + val colPtrs = new Array[Int](numCols + 1) + val rowIndices = MArrayBuilder.make[Int] + rowIndices.sizeHint(numEntries) + val values = MArrayBuilder.make[Double] + values.sizeHint(numEntries) + var nnz = 0 + var prevCol = 0 + var prevRow = -1 + var prevVal = 0.0 + // Append a dummy entry to include the last one at the end of the loop. + (sortedEntries.view :+ (numRows, numCols, 1.0)).foreach { case (i, j, v) => + if (v != 0) { + if (i == prevRow && j == prevCol) { + prevVal += v + } else { + if (prevVal != 0) { + require(prevRow >= 0 && prevRow < numRows, + s"Row index out of range [0, $numRows): $prevRow.") + nnz += 1 + rowIndices += prevRow + values += prevVal + } + prevRow = i + prevVal = v + while (prevCol < j) { + colPtrs(prevCol + 1) = nnz + prevCol += 1 + } + } + } + } + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), values.result()) + } + + /** + * Generate an Identity Matrix in `SparseMatrix` format. + * @param n number of rows and columns of the matrix + * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal + */ + def speye(n: Int): SparseMatrix = { + new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0)) + } + + /** + * Generates the skeleton of a random `SparseMatrix` with a given random number generator. + * The values of the matrix returned are undefined. + */ + private def genRandMatrix( + numRows: Int, + numCols: Int, + density: Double, + rng: Random): SparseMatrix = { + require(numRows > 0, s"numRows must be greater than 0 but got $numRows") + require(numCols > 0, s"numCols must be greater than 0 but got $numCols") + require(density >= 0.0 && density <= 1.0, + s"density must be a double in the range 0.0 <= d <= 1.0. 
Currently, density: $density") + val size = numRows.toLong * numCols + val expected = size * density + assert(expected < Int.MaxValue, + "The expected number of nonzeros cannot be greater than Int.MaxValue.") + val nnz = math.ceil(expected).toInt + if (density == 0.0) { + new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1), Array[Int](), Array[Double]()) + } else if (density == 1.0) { + val colPtrs = Array.tabulate(numCols + 1)(j => j * numRows) + val rowIndices = Array.tabulate(size.toInt)(idx => idx % numRows) + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](numRows * numCols)) + } else if (density < 0.34) { + // draw-by-draw, expected number of iterations is less than 1.5 * nnz + val entries = MHashSet[(Int, Int)]() + while (entries.size < nnz) { + entries += ((rng.nextInt(numRows), rng.nextInt(numCols))) + } + SparseMatrix.fromCOO(numRows, numCols, entries.map(v => (v._1, v._2, 1.0))) + } else { + // selection-rejection method + var idx = 0L + var numSelected = 0 + var j = 0 + val colPtrs = new Array[Int](numCols + 1) + val rowIndices = new Array[Int](nnz) + while (j < numCols && numSelected < nnz) { + var i = 0 + while (i < numRows && numSelected < nnz) { + if (rng.nextDouble() < 1.0 * (nnz - numSelected) / (size - idx)) { + rowIndices(numSelected) = i + numSelected += 1 + } + i += 1 + idx += 1 + } + colPtrs(j + 1) = numSelected + j += 1 + } + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](nnz)) + } + } + + /** + * Generate a `SparseMatrix` consisting of i.i.d. uniform random numbers. The number of non-zero + * elements equal the ceiling of `numRows` x `numCols` x `density` + * + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { + val mat = genRandMatrix(numRows, numCols, density, rng) + mat.update(i => rng.nextDouble()) + } + + /** + * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { + val mat = genRandMatrix(numRows, numCols, density, rng) + mat.update(i => rng.nextGaussian()) + } + + /** + * Generate a diagonal matrix in `SparseMatrix` format from the supplied values. + * @param vector a `Vector` that will form the values on the diagonal of the matrix + * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero + * `values` on the diagonal + */ + def diag(vector: Vector): SparseMatrix = { + val n = vector.size + vector match { + case sVec: SparseVector => + SparseMatrix.fromCOO(n, n, sVec.indices.zip(sVec.values).map(v => (v._1, v._1, v._2))) + case dVec: DenseVector => + val entries = dVec.values.zipWithIndex + val nnzVals = entries.filter(v => v._1 != 0.0) + SparseMatrix.fromCOO(n, n, nnzVals.map(v => (v._2, v._2, v._1))) + } + } } /** @@ -256,72 +572,250 @@ object Matrices { * Generate a `DenseMatrix` consisting of zeros. 
* @param numRows number of rows of the matrix * @param numCols number of columns of the matrix - * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros + * @return `Matrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): Matrix = - new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols) /** * Generate a `DenseMatrix` consisting of ones. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix - * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones + * @return `Matrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): Matrix = - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols) /** - * Generate an Identity Matrix in `DenseMatrix` format. + * Generate a dense Identity Matrix in `Matrix` format. * @param n number of rows and columns of the matrix - * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal + * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ - def eye(n: Int): Matrix = { - val identity = Matrices.zeros(n, n) - var i = 0 - while (i < n){ - identity.update(i, i, 1.0) - i += 1 - } - identity - } + def eye(n: Int): Matrix = DenseMatrix.eye(n) + + /** + * Generate a sparse Identity Matrix in `Matrix` format. + * @param n number of rows and columns of the matrix + * @return `Matrix` with size `n` x `n` and values of ones on the diagonal + */ + def speye(n: Int): Matrix = SparseMatrix.speye(n) /** * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator - * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def rand(numRows: Int, numCols: Int, rng: Random): Matrix = { - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) - } + def rand(numRows: Int, numCols: Int, rng: Random): Matrix = + DenseMatrix.rand(numRows, numCols, rng) + + /** + * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = + SparseMatrix.sprand(numRows, numCols, density, rng) /** * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator - * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def randn(numRows: Int, numCols: Int, rng: Random): Matrix = { - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) - } + def randn(numRows: Int, numCols: Int, rng: Random): Matrix = + DenseMatrix.randn(numRows, numCols, rng) + + /** + * Generate a `SparseMatrix` consisting of i.i.d. 
gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = + SparseMatrix.sprandn(numRows, numCols, density, rng) /** * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. * @param vector a `Vector` tat will form the values on the diagonal of the matrix - * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` + * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ - def diag(vector: Vector): Matrix = { - val n = vector.size - val matrix = Matrices.eye(n) - val values = vector.toArray - var i = 0 - while (i < n) { - matrix.update(i, i, values(i)) - i += 1 + def diag(vector: Vector): Matrix = DenseMatrix.diag(vector) + + /** + * Horizontally concatenate a sequence of matrices. The returned matrix will be in the format + * the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in + * a sparse matrix. If the Array is empty, an empty `DenseMatrix` will be returned. + * @param matrices array of matrices + * @return a single `Matrix` composed of the matrices that were horizontally concatenated + */ + def horzcat(matrices: Array[Matrix]): Matrix = { + if (matrices.isEmpty) { + return new DenseMatrix(0, 0, Array[Double]()) + } else if (matrices.size == 1) { + return matrices(0) + } + val numRows = matrices(0).numRows + var hasSparse = false + var numCols = 0 + matrices.foreach { mat => + require(numRows == mat.numRows, "The number of rows of the matrices in this sequence, " + + "don't match!") + mat match { + case sparse: SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose + case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") + } + numCols += mat.numCols + } + if (!hasSparse) { + new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) + } else { + var startCol = 0 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { + case spMat: SparseMatrix => + var j = 0 + val colPtrs = spMat.colPtrs + val rowIndices = spMat.rowIndices + val values = spMat.values + val data = new Array[(Int, Int, Double)](values.length) + val nCols = spMat.numCols + while (j < nCols) { + var idx = colPtrs(j) + while (idx < colPtrs(j + 1)) { + val i = rowIndices(idx) + val v = values(idx) + data(idx) = (i, j + startCol, v) + idx += 1 + } + j += 1 + } + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + var j = 0 + val nCols = dnMat.numCols + val nRows = dnMat.numRows + val values = dnMat.values + while (j < nCols) { + var i = 0 + val indStart = j * nRows + while (i < nRows) { + val v = values(indStart + i) + if (v != 0.0) { + data.append((i, j + startCol, v)) + } + i += 1 + } + j += 1 + } + startCol += nCols + data + } + SparseMatrix.fromCOO(numRows, numCols, entries) + } + } + + /** + * Vertically concatenate a sequence of matrices. The returned matrix will be in the format + * the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in + * a sparse matrix. If the Array is empty, an empty `DenseMatrix` will be returned. 
+ * @param matrices array of matrices + * @return a single `Matrix` composed of the matrices that were vertically concatenated + */ + def vertcat(matrices: Array[Matrix]): Matrix = { + if (matrices.isEmpty) { + return new DenseMatrix(0, 0, Array[Double]()) + } else if (matrices.size == 1) { + return matrices(0) + } + val numCols = matrices(0).numCols + var hasSparse = false + var numRows = 0 + matrices.foreach { mat => + require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + + "don't match!") + mat match { + case sparse: SparseMatrix => + hasSparse = true + case dense: DenseMatrix => + case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") + } + numRows += mat.numRows + + } + if (!hasSparse) { + val allValues = new Array[Double](numRows * numCols) + var startRow = 0 + matrices.foreach { mat => + var j = 0 + val nRows = mat.numRows + val values = mat.toArray + while (j < numCols) { + var i = 0 + val indStart = j * numRows + startRow + val subMatStart = j * nRows + while (i < nRows) { + allValues(indStart + i) = values(subMatStart + i) + i += 1 + } + j += 1 + } + startRow += nRows + } + new DenseMatrix(numRows, numCols, allValues) + } else { + var startRow = 0 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { + case spMat: SparseMatrix => + var j = 0 + val colPtrs = spMat.colPtrs + val rowIndices = spMat.rowIndices + val values = spMat.values + val data = new Array[(Int, Int, Double)](values.length) + while (j < numCols) { + var idx = colPtrs(j) + while (idx < colPtrs(j + 1)) { + val i = rowIndices(idx) + val v = values(idx) + data(idx) = (i + startRow, j, v) + idx += 1 + } + j += 1 + } + startRow += spMat.numRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + var j = 0 + val nCols = dnMat.numCols + val nRows = dnMat.numRows + val values = dnMat.values + while (j < nCols) { + var i = 0 + val indStart = j * nRows + while (i < nRows) { + val v = values(indStart + i) + if (v != 0.0) { + data.append((i + startRow, j, v)) + } + i += 1 + } + j += 1 + } + startRow += nRows + data + } + SparseMatrix.fromCOO(numRows, numCols, entries) } - matrix } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 47d1a76fa361d..01f3f90577142 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -268,7 +268,7 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. 
*/ - private[spark] def norm(vector: Vector, p: Double): Double = { + def norm(vector: Vector, p: Double): Double = { require(p >= 1.0) val values = vector match { case dv: DenseVector => dv.values diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 5c1acca0ec532..36d8cadd2bdd7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -142,7 +142,7 @@ class IndexedRowMatrix( val mat = BDM.zeros[Double](m, n) rows.collect().foreach { case IndexedRow(rowIndex, vector) => val i = rowIndex.toInt - vector.toBreeze.activeIterator.foreach { case (j, v) => + vector.foreachActive { case (j, v) => mat(i, j) = v } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 10a515af88802..a3fca53929ab7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -588,8 +588,8 @@ class RowMatrix( val n = numCols().toInt val mat = BDM.zeros[Double](m, n) var i = 0 - rows.collect().foreach { v => - v.toBreeze.activeIterator.foreach { case (j, v) => + rows.collect().foreach { vector => + vector.foreachActive { case (j, v) => mat(i, j) = v } i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index f9791c6571782..8ecd5c6ad93c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -45,7 +45,7 @@ class LassoModel ( /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index c8cad773f5efb..076ba35051c9d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -45,7 +45,7 @@ class RidgeRegressionModel ( /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. * This solves the l2-regularized least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + regParam/2 ||weights||^2 + * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation.
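For reference, a minimal sketch of the corrected objective that the Lasso and Ridge Scaladoc fixes above describe, using the now-public `Vectors.norm` for the L1 term. It is not part of this patch: the object and helper names and the plain-Scala dot product are illustrative assumptions, not MLlib API.

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}

object ObjectiveSketch {
  // Hypothetical helper (not MLlib API): evaluates
  //   f(w) = 1/(2n) * ||A w - y||^2 + regParam * ||w||_1
  // for a small in-memory set of (row of A, label y) pairs.
  def lassoObjective(rows: Seq[(Vector, Double)], weights: Vector, regParam: Double): Double = {
    val n = rows.size
    val w = weights.toArray
    val squaredError = rows.map { case (a, y) =>
      // dense dot product of one row of A with the weight vector
      val residual = a.toArray.zip(w).map { case (ai, wi) => ai * wi }.sum - y
      residual * residual
    }.sum
    // L1 penalty via the now-public Vectors.norm(_, 1.0)
    squaredError / (2.0 * n) + regParam * Vectors.norm(weights, 1.0)
  }
}
```

The `1/2n` scaling is the substance of the doc fix: the squared error is averaged over the `n` rows and then halved, rather than only averaged.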
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 8db0442a7a569..b549b7c475fc3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -22,7 +22,6 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala new file mode 100644 index 0000000000000..2eab5d277827d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.impl + +import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, Transpose, det, pinv} + +/** + * Utility class to implement the density function for multivariate Gaussian distribution. + * Breeze provides this functionality, but it requires the Apache Commons Math library, + * so this class is here so-as to not introduce a new dependency in Spark. + */ +private[mllib] class MultivariateGaussian( + val mu: DBV[Double], + val sigma: DBM[Double]) extends Serializable { + private val sigmaInv2 = pinv(sigma) * -0.5 + private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5) + + /** Returns density of this multivariate Gaussian at given point, x */ + def pdf(x: DBV[Double]): Double = { + val delta = x - mu + val deltaTranspose = new Transpose(delta) + U * math.exp(deltaTranspose * sigmaInv2 * delta) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 73e7e32c6db31..b3e8ed9af8c51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -64,13 +64,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val rfModel = rf.run(input) rfModel.trees(0) } - - /** - * Trains a decision tree model over an RDD. This is deprecated because it hides the static - * methods with the same name in Java. 
- */ - @deprecated("Please use DecisionTree.run instead.", "1.2.0") - def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input) } object DecisionTree extends Serializable with Logging { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index b0d05ae33e1b5..da0da0a168c1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.dstream.DStream */ object MLUtils { - private[util] lazy val EPSILON = { + private[mllib] lazy val EPSILON = { var eps = 1.0 while ((1.0 + (eps / 2.0)) != 1.0) { eps /= 2.0 @@ -154,10 +154,12 @@ object MLUtils { def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) { // TODO: allow to specify label precision and feature precision. val dataStr = data.map { case LabeledPoint(label, features) => - val featureStrings = features.toBreeze.activeIterator.map { case (i, v) => - s"${i + 1}:$v" + val sb = new StringBuilder(label.toString) + features.foreachActive { case (i, v) => + sb += ' ' + sb ++= s"${i + 1}:$v" } - (Iterator(label) ++ featureStrings).mkString(" ") + sb.mkString } dataStr.saveAsTextFile(dir) } diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index 064263e02cd11..fbc26167ce66f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -49,6 +49,7 @@ public void tearDown() { public void tfIdf() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); + @SuppressWarnings("unchecked") JavaRDD> documents = sc.parallelize(Lists.newArrayList( Lists.newArrayList("this is a sentence".split(" ")), Lists.newArrayList("this is another sentence".split(" ")), @@ -68,6 +69,7 @@ public void tfIdf() { public void tfIdfMinimumDocumentFrequency() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); + @SuppressWarnings("unchecked") JavaRDD> documents = sc.parallelize(Lists.newArrayList( Lists.newArrayList("this is a sentence".split(" ")), Lists.newArrayList("this is another sentence".split(" ")), diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java new file mode 100644 index 0000000000000..704d484d0b585 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.linalg; + +import static org.junit.Assert.*; +import org.junit.Test; + +import java.io.Serializable; +import java.util.Random; + +public class JavaMatricesSuite implements Serializable { + + @Test + public void randMatrixConstruction() { + Random rng = new Random(24); + Matrix r = Matrices.rand(3, 4, rng); + rng.setSeed(24); + DenseMatrix dr = DenseMatrix.rand(3, 4, rng); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + + rng.setSeed(24); + Matrix rn = Matrices.randn(3, 4, rng); + rng.setSeed(24); + DenseMatrix drn = DenseMatrix.randn(3, 4, rng); + assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); + + rng.setSeed(24); + Matrix s = Matrices.sprand(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); + assertArrayEquals(s.toArray(), sr.toArray(), 0.0); + + rng.setSeed(24); + Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); + assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); + } + + @Test + public void identityMatrixConstruction() { + Matrix r = Matrices.eye(2); + DenseMatrix dr = DenseMatrix.eye(2); + SparseMatrix sr = SparseMatrix.speye(2); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); + assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); + } + + @Test + public void diagonalMatrixConstruction() { + Vector v = Vectors.dense(1.0, 0.0, 2.0); + Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); + + Matrix m = Matrices.diag(v); + Matrix sm = Matrices.diag(sv); + DenseMatrix d = DenseMatrix.diag(v); + DenseMatrix sd = DenseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.diag(v); + SparseMatrix ss = SparseMatrix.diag(sv); + + assertArrayEquals(m.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sd.toArray(), 0.0); + assertArrayEquals(sd.toArray(), s.toArray(), 0.0); + assertArrayEquals(s.toArray(), ss.toArray(), 0.0); + assertArrayEquals(s.values(), ss.values(), 0.0); + assert(s.values().length == 2); + assert(ss.values().length == 2); + assert(s.colPtrs().length == 4); + assert(ss.colPtrs().length == 4); + } + + @Test + public void zerosMatrixConstruction() { + Matrix z = Matrices.zeros(2, 2); + Matrix one = Matrices.ones(2, 2); + DenseMatrix dz = DenseMatrix.zeros(2, 2); + DenseMatrix done = DenseMatrix.ones(2, 2); + + assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + } + + @Test + public void sparseDenseConversion() { + int m = 3; + int n = 2; + double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; + double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; + int[] colPtrs = new int[]{0, 2, 4}; + int[] rowIndices = new int[]{0, 1, 1, 2}; + + SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); + DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); + + SparseMatrix spMat2 = deMat1.toSparse(); + DenseMatrix deMat2 = spMat1.toDense(); + + assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); + assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); + } + + @Test + public void concatenateMatrices() { + int m = 3; + int n = 2; + + Random rng = new Random(42); + SparseMatrix spMat1 = 
SparseMatrix.sprand(m, n, 0.5, rng); + rng.setSeed(42); + DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); + Matrix deMat2 = Matrices.eye(3); + Matrix spMat2 = Matrices.speye(3); + Matrix deMat3 = Matrices.eye(2); + Matrix spMat3 = Matrices.speye(2); + + Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); + Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); + Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); + Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); + + assert(deHorz1.numRows() == 3); + assert(deHorz2.numRows() == 3); + assert(deHorz3.numRows() == 3); + assert(spHorz.numRows() == 3); + assert(deHorz1.numCols() == 5); + assert(deHorz2.numCols() == 5); + assert(deHorz3.numCols() == 5); + assert(spHorz.numCols() == 5); + + Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); + Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); + Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); + Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); + + assert(deVert1.numRows() == 5); + assert(deVert2.numRows() == 5); + assert(deVert3.numRows() == 5); + assert(spVert.numRows() == 5); + assert(deVert1.numCols() == 2); + assert(deVert2.numCols() == 2); + assert(deVert3.numCols() == 2); + assert(spVert.numCols() == 2); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala new file mode 100644 index 0000000000000..23feb82874b70 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext { + test("single cluster") { + val data = sc.parallelize(Array( + Vectors.dense(6.0, 9.0), + Vectors.dense(5.0, 10.0), + Vectors.dense(4.0, 11.0) + )) + + // expectations + val Ew = 1.0 + val Emu = Vectors.dense(5.0, 10.0) + val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) + + val gmm = new GaussianMixtureEM().setK(1).run(data) + + assert(gmm.weight(0) ~== Ew absTol 1E-5) + assert(gmm.mu(0) ~== Emu absTol 1E-5) + assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + } + + test("two clusters") { + val data = sc.parallelize(Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + )) + + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array(Vectors.dense(-1.0), Vectors.dense(1.0)), + Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0))) + ) + + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val gmm = new GaussianMixtureEM() + .setK(2) + .setInitialModel(initialGmm) + .run(data) + + assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) + assert(gmm.weight(1) ~== Ew(1) absTol 1E-3) + assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) + assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) + assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) + assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 322a0e9242918..a35d0fe389fdd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -43,9 +43,9 @@ class MatricesSuite extends FunSuite { test("sparse matrix construction") { val m = 3 - val n = 2 + val n = 4 val values = Array(1.0, 2.0, 4.0, 5.0) - val colPtrs = Array(0, 2, 4) + val colPtrs = Array(0, 2, 2, 4, 4) val rowIndices = Array(1, 2, 1, 2) val mat = Matrices.sparse(m, n, colPtrs, rowIndices, values).asInstanceOf[SparseMatrix] assert(mat.numRows === m) @@ -53,6 +53,13 @@ class MatricesSuite extends FunSuite { assert(mat.values.eq(values), "should not copy data") assert(mat.colPtrs.eq(colPtrs), "should not copy data") assert(mat.rowIndices.eq(rowIndices), "should not copy data") + + val entries: Array[(Int, Int, Double)] = Array((2, 2, 3.0), (1, 0, 1.0), (2, 0, 2.0), + (1, 2, 2.0), (2, 2, 2.0), (1, 2, 2.0), (0, 0, 0.0)) + + val mat2 = SparseMatrix.fromCOO(m, n, entries) + assert(mat.toBreeze === mat2.toBreeze) + assert(mat2.values.length == 4) } test("sparse matrix construction with wrong number of elements") { @@ -117,6 +124,142 @@ class MatricesSuite extends FunSuite { assert(sparseMat.values(2) === 10.0) } + test("toSparse, toDense") { + val m 
= 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val deMat1 = new DenseMatrix(m, n, allValues) + + val spMat2 = deMat1.toSparse() + val deMat2 = spMat1.toDense() + + assert(spMat1.toBreeze === spMat2.toBreeze) + assert(deMat1.toBreeze === deMat2.toBreeze) + } + + test("map, update") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val deMat1 = new DenseMatrix(m, n, allValues) + val deMat2 = deMat1.map(_ * 2) + val spMat2 = spMat1.map(_ * 2) + deMat1.update(_ * 2) + spMat1.update(_ * 2) + + assert(spMat1.toArray === spMat2.toArray) + assert(deMat1.toArray === deMat2.toArray) + } + + test("horzcat, vertcat, eye, speye") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val deMat1 = new DenseMatrix(m, n, allValues) + val deMat2 = Matrices.eye(3) + val spMat2 = Matrices.speye(3) + val deMat3 = Matrices.eye(2) + val spMat3 = Matrices.speye(2) + + val spHorz = Matrices.horzcat(Array(spMat1, spMat2)) + val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) + val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) + val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) + + val deHorz2 = Matrices.horzcat(Array[Matrix]()) + + assert(deHorz1.numRows === 3) + assert(spHorz2.numRows === 3) + assert(spHorz3.numRows === 3) + assert(spHorz.numRows === 3) + assert(deHorz1.numCols === 5) + assert(spHorz2.numCols === 5) + assert(spHorz3.numCols === 5) + assert(spHorz.numCols === 5) + assert(deHorz2.numRows === 0) + assert(deHorz2.numCols === 0) + assert(deHorz2.toArray.length === 0) + + assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) + assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(spHorz(0, 0) === 1.0) + assert(spHorz(2, 1) === 5.0) + assert(spHorz(0, 2) === 1.0) + assert(spHorz(1, 2) === 0.0) + assert(spHorz(1, 3) === 1.0) + assert(spHorz(2, 4) === 1.0) + assert(spHorz(1, 4) === 0.0) + assert(deHorz1(0, 0) === 1.0) + assert(deHorz1(2, 1) === 5.0) + assert(deHorz1(0, 2) === 1.0) + assert(deHorz1(1, 2) == 0.0) + assert(deHorz1(1, 3) === 1.0) + assert(deHorz1(2, 4) === 1.0) + assert(deHorz1(1, 4) === 0.0) + + intercept[IllegalArgumentException] { + Matrices.horzcat(Array(spMat1, spMat3)) + } + + intercept[IllegalArgumentException] { + Matrices.horzcat(Array(deMat1, spMat3)) + } + + val spVert = Matrices.vertcat(Array(spMat1, spMat3)) + val deVert1 = Matrices.vertcat(Array(deMat1, deMat3)) + val spVert2 = Matrices.vertcat(Array(spMat1, deMat3)) + val spVert3 = Matrices.vertcat(Array(deMat1, spMat3)) + val deVert2 = Matrices.vertcat(Array[Matrix]()) + + assert(deVert1.numRows === 5) + assert(spVert2.numRows === 5) + assert(spVert3.numRows === 5) + assert(spVert.numRows === 5) + assert(deVert1.numCols === 2) + assert(spVert2.numCols === 2) + assert(spVert3.numCols === 2) + assert(spVert.numCols === 2) + assert(deVert2.numRows === 0) + assert(deVert2.numCols === 0) + assert(deVert2.toArray.length === 0) + + assert(deVert1.toBreeze.toDenseMatrix === 
spVert2.toBreeze.toDenseMatrix) + assert(spVert2.toBreeze === spVert3.toBreeze) + assert(spVert(0, 0) === 1.0) + assert(spVert(2, 1) === 5.0) + assert(spVert(3, 0) === 1.0) + assert(spVert(3, 1) === 0.0) + assert(spVert(4, 1) === 1.0) + assert(deVert1(0, 0) === 1.0) + assert(deVert1(2, 1) === 5.0) + assert(deVert1(3, 0) === 1.0) + assert(deVert1(3, 1) === 0.0) + assert(deVert1(4, 1) === 1.0) + + intercept[IllegalArgumentException] { + Matrices.vertcat(Array(spMat1, spMat2)) + } + + intercept[IllegalArgumentException] { + Matrices.vertcat(Array(deMat1, spMat2)) + } + } + test("zeros") { val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix] assert(mat.numRows === 2) @@ -162,4 +305,29 @@ class MatricesSuite extends FunSuite { assert(mat.numCols === 2) assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 2.0)) } + + test("sprand") { + val rng = mock[Random] + when(rng.nextInt(4)).thenReturn(0, 1, 1, 3, 2, 2, 0, 1, 3, 0) + when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0) + val mat = SparseMatrix.sprand(4, 4, 0.25, rng) + assert(mat.numRows === 4) + assert(mat.numCols === 4) + assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1)) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + val mat2 = SparseMatrix.sprand(2, 3, 1.0, rng) + assert(mat2.rowIndices.toSeq === Seq(0, 1, 0, 1, 0, 1)) + assert(mat2.colPtrs.toSeq === Seq(0, 2, 4, 6)) + } + + test("sprandn") { + val rng = mock[Random] + when(rng.nextInt(4)).thenReturn(0, 1, 1, 3, 2, 2, 0, 1, 3, 0) + when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = SparseMatrix.sprandn(4, 4, 0.25, rng) + assert(mat.numRows === 4) + assert(mat.numCols === 4) + assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1)) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 30b906aaa3ba4..e957fa5d25f4c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -178,17 +178,17 @@ object TestingUtils { implicit class MatrixWithAlmostEquals(val x: Matrix) { /** - * When the difference of two vectors are within eps, returns true; otherwise, returns false. + * When the difference of two matrices are within eps, returns true; otherwise, returns false. */ def ~=(r: CompareMatrixRightSide): Boolean = r.fun(x, r.y, r.eps) /** - * When the difference of two vectors are within eps, returns false; otherwise, returns true. + * When the difference of two matrices are within eps, returns false; otherwise, returns true. */ def !~=(r: CompareMatrixRightSide): Boolean = !r.fun(x, r.y, r.eps) /** - * Throws exception when the difference of two vectors are NOT within eps; + * Throws exception when the difference of two matrices are NOT within eps; * otherwise, returns true. */ def ~==(r: CompareMatrixRightSide): Boolean = { diff --git a/pom.xml b/pom.xml index e4db1393ba9cf..05f59a9b4140b 100644 --- a/pom.xml +++ b/pom.xml @@ -149,6 +149,7 @@ 2.10 ${scala.version} org.scala-lang + 1.8.8 @@ -819,10 +820,15 @@ - + org.codehaus.jackson jackson-mapper-asl - 1.8.8 + ${jackson.version} + + + org.codehaus.jackson + jackson-core-asl + ${jackson.version} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f52074282e1f7..c512b62f6137e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -15,6 +15,8 @@ * limitations under the License. 
*/ +import java.io.File + import scala.util.Properties import scala.collection.JavaConversions._ @@ -23,7 +25,7 @@ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion -import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} +import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings object BuildCommons { @@ -112,6 +114,17 @@ object SparkBuild extends PomBuild { override val userPropertiesMap = System.getProperties.toMap + // Handle case where hadoop.version is set via profile. + // Needed only because we read back this property in sbt + // when we create the assembly jar. + val pom = loadEffectivePom(new File("pom.xml"), + profiles = profiles, + userProps = userPropertiesMap) + if (System.getProperty("hadoop.version") == null) { + System.setProperty("hadoop.version", + pom.getProperties.get("hadoop.version").asInstanceOf[String]) + } + lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -297,8 +310,7 @@ object Assembly { // This must match the same name used in maven (see network/yarn/pom.xml) "spark-" + v + "-yarn-shuffle.jar" } else { - mName + "-" + v + "-hadoop" + - Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" + mName + "-" + v + "-hadoop" + System.getProperty("hadoop.version") + ".jar" } }, mergeStrategy in assembly := { diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 9807a84a66f11..0e8b398fc6b97 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1671,7 +1671,7 @@ def _ssql_ctx(self): except Py4JError as e: raise Exception("You must build Spark with Hive. " "Export 'SPARK_HIVE=true' and run " - "sbt/sbt assembly", e) + "build/sbt assembly", e) def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) diff --git a/sbt/sbt b/sbt/sbt index 0a251d97db95c..41438251f681e 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -1,111 +1,29 @@ #!/usr/bin/env bash -# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so -# that we can run Hive to generate the golden answer. This is not required for normal development -# or testing. -for i in "$HIVE_HOME"/lib/* -do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" -done -export HADOOP_CLASSPATH - -realpath () { -( - TARGET_FILE="$1" - - cd "$(dirname "$TARGET_FILE")" - TARGET_FILE="$(basename "$TARGET_FILE")" - - COUNT=0 - while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] - do - TARGET_FILE="$(readlink "$TARGET_FILE")" - cd $(dirname "$TARGET_FILE") - TARGET_FILE="$(basename $TARGET_FILE)" - COUNT=$(($COUNT + 1)) - done - - echo "$(pwd -P)/"$TARGET_FILE"" -) -} - -. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash - - -declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" -declare -r sbt_opts_file=".sbtopts" -declare -r etc_sbt_opts_file="/etc/sbt/sbtopts" - -usage() { - cat < path to global settings/plugins directory (default: ~/.sbt) - -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series) - -ivy path to local Ivy repository (default: ~/.ivy2) - -mem set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem)) - -no-share use all local caches; no sharing - -no-global uses global caches, but does not use global ~/.sbt directory. - -jvm-debug Turn on JVM debugging, open at the given port. 
- -batch Disable interactive mode - - # sbt version (default: from project/build.properties if present, else latest release) - -sbt-version use the specified version of sbt - -sbt-jar use the specified jar as the sbt launcher - -sbt-rc use an RC version of sbt - -sbt-snapshot use a snapshot version of sbt - - # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) - -java-home alternate JAVA_HOME - - # jvm options and output control - JAVA_OPTS environment variable, if unset uses "$java_opts" - SBT_OPTS environment variable, if unset uses "$default_sbt_opts" - .sbtopts if this file exists in the current directory, it is - prepended to the runner args - /etc/sbt/sbtopts if this file exists, it is prepended to the runner args - -Dkey=val pass -Dkey=val directly to the java runtime - -J-X pass option -X directly to the java runtime - (-J is stripped) - -S-X add -X to sbt's scalacOptions (-S is stripped) - -PmavenProfiles Enable a maven profile for the build. - -In the case of duplicated or conflicting options, the order above -shows precedence: JAVA_OPTS lowest, command line options highest. -EOM -} - -process_my_args () { - while [[ $# -gt 0 ]]; do - case "$1" in - -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; - -no-share) addJava "$noshare_opts" && shift ;; - -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;; - -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; - -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;; - -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; - -batch) exec &2 +echo " Please update references to point to the new location." >&2 +echo "" >&2 +echo " Invoking 'build/sbt $@' now ..." >&2 +echo "" >&2 + +${_DIR}/../build/sbt "$@" diff --git a/sql/README.md b/sql/README.md index c84534da9a3d3..8d2f3cf4283e0 100644 --- a/sql/README.md +++ b/sql/README.md @@ -22,10 +22,10 @@ export HADOOP_HOME="/hadoop-1.0.4" Using the console ================= -An interactive scala console can be invoked by running `sbt/sbt hive/console`. From here you can execute queries and inspect the various stages of query optimization. +An interactive scala console can be invoked by running `build/sbt hive/console`. From here you can execute queries and inspect the various stages of query optimization. ```scala -catalyst$ sbt/sbt hive/console +catalyst$ build/sbt hive/console [info] Starting scala interpreter... import org.apache.spark.sql.catalyst.analysis._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index d4fc9bbfd3118..f79d4ff444dc0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -204,20 +204,16 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] = - ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, l) } - | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => SortPartitions(o, l) } + ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) } + | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) } ) protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(singleOrder, ",") - | rep1sep(expression, ",") ~ direction.? 
^^ { - case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending))) + ( rep1sep(expression ~ direction.? , ",") ^^ { + case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) - protected lazy val singleOrder: Parser[SortOrder] = - expression ~ direction ^^ { case e ~ o => SortOrder(e, o) } - protected lazy val direction: Parser[SortDirection] = ( ASC ^^^ Ascending | DESC ^^^ Descending diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1c4088b8438e1..72680f37a0b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -246,7 +246,7 @@ class Analyzer(catalog: Catalog, case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. - case p@Project(projectList, child) if containsStar(projectList) => + case p @ Project(projectList, child) if containsStar(projectList) => Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) @@ -310,7 +310,8 @@ class Analyzer(catalog: Catalog, */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => + case s @ Sort(ordering, global, p @ Project(projectList, child)) + if !s.resolved && p.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) val resolved = unresolved.flatMap(child.resolve(_, resolver)) val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) @@ -319,13 +320,14 @@ class Analyzer(catalog: Catalog, if (missingInProject.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(projectList.map(_.toAttribute), - Sort(ordering, + Sort(ordering, global, Project(projectList ++ missingInProject, child))) } else { logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } - case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => + case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) + if !s.resolved && a.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) // A small hack to create an object that will allow us to resolve any references that // refer to named expressions that are present in the grouping expressions. @@ -340,8 +342,7 @@ class Analyzer(catalog: Catalog, if (missingInAggs.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(ordering, - Aggregate(grouping, aggs ++ missingInAggs, child))) + Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child))) } else { s // Nothing we can do here. Return original plan. 
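For context: after this change both ORDER BY and SORT BY parse into the single `Sort` logical node, distinguished only by its `global` flag. A minimal stand-alone sketch of the idea, using toy types rather than Catalyst's classes:

```scala
// Toy stand-ins, illustrative only -- not Catalyst's LogicalPlan/SortOrder.
sealed trait Plan
case class Relation(name: String) extends Plan
case class Sort(order: Seq[String], global: Boolean, child: Plan) extends Plan

// ORDER BY -> total ordering of the result; SORT BY -> per-partition ordering.
def orderBy(order: Seq[String], child: Plan): Plan = Sort(order, global = true, child)
def sortBy(order: Seq[String], child: Plan): Plan = Sort(order, global = false, child)

orderBy(Seq("a ASC"), Relation("t"))  // Sort(List(a ASC), true,  Relation(t))
sortBy(Seq("a ASC"), Relation("t"))   // Sort(List(a ASC), false, Relation(t))
```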
} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e38114ab3cf25..242f28f670298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -361,6 +361,22 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b: BinaryExpression if b.left.dataType != b.right.dataType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index fb252cdf51534..a14e5b9ef14d0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -244,9 +244,9 @@ package object dsl { condition: Option[Expression] = None) = Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, logicalPlan) + def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*) = SortPartitions(sortExprs, logicalPlan) + def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan) def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { val aliasedExprs = aggregateExprs.map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 94b6fb084d38a..cb5ff67959868 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.BooleanType @@ -48,6 +47,14 @@ trait PredicateHelper { } } + protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = { + condition match { + case Or(cond1, cond2) => + splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2) + case other => other :: Nil + } + } + /** * Returns true if 
`expr` can be evaluated using only the output of `plan`. This method * can be used to determine when is is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 806c1394eb151..cd3137980ca43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -142,16 +142,16 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList1, Project(projectList2, child)) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliasMap = projectList2.collect { - case a @ Alias(e, _) => (a.toAttribute: Expression, a) - }.toMap + val aliasMap = AttributeMap(projectList2.collect { + case a @ Alias(e, _) => (a.toAttribute, a) + }) // Substitute any attributes that are produced by the child projection, so that we safely // eliminate it. // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' // TODO: Fix TransformBase to avoid the cast below. val substitutedProjection = projectList1.map(_.transform { - case a if aliasMap.contains(a) => aliasMap(a) + case a: Attribute if aliasMap.contains(a) => aliasMap(a) }).asInstanceOf[Seq[NamedExpression]] Project(substitutedProjection, child) @@ -294,11 +294,16 @@ object OptimizeIn extends Rule[LogicalPlan] { } /** - * Simplifies boolean expressions where the answer can be determined without evaluating both sides. + * Simplifies boolean expressions: + * + * 1. Simplifies expressions whose answer can be determined without evaluating both sides. + * 2. Eliminates / extracts common factors. + * 3. Removes `Not` operator. + * * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus * is only safe when evaluations of expressions does not result in side effects. */ -object BooleanSimplification extends Rule[LogicalPlan] { +object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case and @ And(left, right) => @@ -307,7 +312,9 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (l, Literal(true, BooleanType)) => l case (Literal(false, BooleanType), _) => Literal(false) case (_, Literal(false, BooleanType)) => Literal(false) - case (_, _) => and + // a && a && a ... => a + case _ if splitConjunctivePredicates(and).distinct.size == 1 => left + case _ => and } case or @ Or(left, right) => @@ -316,7 +323,19 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, Literal(true, BooleanType)) => Literal(true) case (Literal(false, BooleanType), r) => r case (l, Literal(false, BooleanType)) => l - case (_, _) => or + // a || a || a ... => a + case _ if splitDisjunctivePredicates(or).distinct.size == 1 => left + // (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...) 
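The factoring step described by the comment above, `(a && b && c) || (a && b && d) => a && b && (c || d)`, can be sketched in isolation. The following toy version uses a hand-rolled expression type, not Catalyst's `Expression` hierarchy, and ignores the literal-folding cases handled earlier in the rule:

```scala
sealed trait Expr
case class Pred(name: String) extends Expr
case class And(left: Expr, right: Expr) extends Expr
case class Or(left: Expr, right: Expr) extends Expr

// Flatten nested ANDs into their conjuncts.
def conjuncts(e: Expr): Seq[Expr] = e match {
  case And(l, r) => conjuncts(l) ++ conjuncts(r)
  case other     => Seq(other)
}

// (a && b && c) || (a && b && d)  ==>  a && b && (c || d)
def factorOr(left: Expr, right: Expr): Expr = {
  val lhs    = conjuncts(left).toSet
  val rhs    = conjuncts(right).toSet
  val common = lhs intersect rhs
  val rest   = (lhs.diff(common).reduceOption(And) ++ rhs.diff(common).reduceOption(And))
    .reduceOption(Or)
  (rest ++ common).reduce(And)
}

val (a, b, c, d) = (Pred("a"), Pred("b"), Pred("c"), Pred("d"))
factorOr(And(And(a, b), c), And(And(a, b), d))
// => (c || d) && a && b, up to ordering of the conjuncts
```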
+ case _ => + val lhsSet = splitConjunctivePredicates(left).toSet + val rhsSet = splitConjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + + (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And)) + .reduceOption(Or) + .map(_ :: common.toList) + .getOrElse(common.toList) + .reduce(And) } case not @ Not(exp) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a9282b98adfab..0b9f01cbae9ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -130,7 +130,16 @@ case class WriteToFile( override def output = child.output } -case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { +/** + * @param order The ordering expressions + * @param global True means global sorting apply for entire data set, + * False means sorting only apply within the partition. + * @param child Child logical plan + */ +case class Sort( + order: Seq[SortOrder], + global: Boolean, + child: LogicalPlan) extends UnaryNode { override def output = child.output } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index d5b7d2789a103..3677a6e72e23a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -49,6 +49,15 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { assert(analyzer(plan).schema.fields(0).dataType === expectedType) } + private def checkComparison(expression: Expression, expectedType: DataType): Unit = { + val plan = Project(Alias(expression, "c")() :: Nil, relation) + val comparison = analyzer(plan).collect { + case Project(Alias(e: BinaryComparison, _) :: Nil, _) => e + }.head + assert(comparison.left.dataType === expectedType) + assert(comparison.right.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -65,6 +74,14 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) } + test("Comparison operations") { + checkComparison(LessThan(i, d1), DecimalType.Unlimited) + checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) + checkComparison(GreaterThanOrEqual(d1, f), DoubleType) + checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala new file mode 100644 index 0000000000000..906300d8336cb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class NormalizeFiltersSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq( + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators), + Batch("NormalizeFilters", FixedPoint(100), + BooleanSimplification, + SimplifyFilters)) + } + + val relation = LocalRelation('a.int, 'b.int, 'c.string) + + def checkExpression(original: Expression, expected: Expression): Unit = { + val actual = Optimize(relation.where(original)).collect { case f: Filter => f.condition }.head + val result = (actual, expected) match { + case (And(l1, r1), And(l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1) + case (Or (l1, r1), Or (l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1) + case (lhs, rhs) => lhs fastEquals rhs + } + + assert(result, s"$actual isn't equivalent to $expected") + } + + test("a && a => a") { + checkExpression('a === 1 && 'a === 1, 'a === 1) + checkExpression('a === 1 && 'a === 1 && 'a === 1, 'a === 1) + } + + test("a || a => a") { + checkExpression('a === 1 || 'a === 1, 'a === 1) + checkExpression('a === 1 || 'a === 1 || 'a === 1, 'a === 1) + } + + test("(a && b) || (a && c) => a && (b || c)") { + checkExpression( + ('a === 1 && 'a < 10) || ('a > 2 && 'a === 1), + ('a === 1) && ('a < 10 || 'a > 2)) + + checkExpression( + ('a < 1 && 'b > 2 && 'c.isNull) || ('a < 1 && 'c === "hello" && 'b > 2), + ('c.isNull || 'c === "hello") && 'a < 1 && 'b > 2) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java index b751847b464fd..f0d079d25b5d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java @@ -35,6 +35,7 @@ protected UserDefinedType() { } public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") UserDefinedType that = (UserDefinedType) o; return this.sqlType().equals(that.sqlType()); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 856b10f1a8fd8..80787b61ce1bf 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -214,7 +214,7 @@ class SchemaRDD( * @group Query */ def orderBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan)) + new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan)) /** * Sorts the results by the given expressions within partition. @@ -227,7 +227,7 @@ class SchemaRDD( * @group Query */ def sortBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, SortPartitions(sortExprs, logicalPlan)) + new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan)) @deprecated("use limit with integer argument", "1.1.0") def limit(limitExpr: Expression): SchemaRDD = @@ -238,7 +238,6 @@ class SchemaRDD( * {{{ * schemaRDD.limit(10) * }}} - * * @group Query */ def limit(limitNum: Int): SchemaRDD = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2954d4ce7d2d8..ce878c137e627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + val semiJoin = joins.BroadcastLeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => val semiJoin = joins.LeftSemiJoinHash( @@ -190,7 +196,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } @@ -257,15 +263,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Distinct(partial = false, execution.Distinct(partial = true, planLater(child))) :: Nil - case logical.Sort(sortExprs, child) if sqlContext.externalSortEnabled => - execution.ExternalSort(sortExprs, global = true, planLater(child)):: Nil - case logical.Sort(sortExprs, child) => - execution.Sort(sortExprs, global = true, planLater(child)):: Nil - case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. 
execution.Sort(sortExprs, global = false, planLater(child)) :: Nil + case logical.Sort(sortExprs, global, child) if sqlContext.externalSortEnabled => + execution.ExternalSort(sortExprs, global, planLater(child)):: Nil + case logical.Sort(sortExprs, global, child) => + execution.Sort(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala new file mode 100644 index 0000000000000..2ab064fd0151e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. 
+ */ +@DeveloperApi +case class BroadcastLeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def output = left.output + + override def execute() = { + val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val broadcastedRelation = sparkContext.broadcast(hashSet) + + streamedPlan.execute().mapPartitions { streamIter => + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index ef3687e692964..9049eb5932b79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -130,7 +130,7 @@ private[parquet] object RowReadSupport { private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] var writer: RecordConsumer = null - private[parquet] var attributes: Seq[Attribute] = null + private[parquet] var attributes: Array[Attribute] = null override def init(configuration: Configuration): WriteSupport.WriteContext = { val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) @@ -138,7 +138,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) if (attributes == null) { - attributes = ParquetTypesConverter.convertFromString(origAttributesStr) + attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray } log.debug(s"write support initialized for requested schema $attributes") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 0e6fb57d57bca..97447871a11ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -24,8 +24,8 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import parquet.format.converter.ParquetMetadataConverter import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} import parquet.hadoop.util.ContextUtil @@ -458,7 +458,7 @@ private[parquet] object ParquetTypesConverter extends Logging { // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is // empty, thus normally the "_metadata" file is expected to be fairly small). 
.orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) - .map(ParquetFileReader.readFooter(conf, _)) + .map(ParquetFileReader.readFooter(conf, _, ParquetMetadataConverter.NO_FILTER)) .getOrElse( throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java index bc5cd66482add..2b5812159d07d 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -141,6 +141,7 @@ public void constructComplexRow() { doubleValue, stringValue, timestampValue, null); // Complex array + @SuppressWarnings("unchecked") List> arrayOfMaps = Arrays.asList(simpleMap); List arrayOfRows = Arrays.asList(simpleStruct); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index e40d034ce4dc0..c0b9cf5163120 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.test.TestSQLContext._ +import scala.language.postfixOps + class DslQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -86,7 +88,7 @@ class DslQuerySuite extends QueryTest { Seq(Seq(6))) } - test("sorting") { + test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) @@ -120,22 +122,31 @@ class DslQuerySuite extends QueryTest { mapData.collect().sortBy(_.data(1)).reverse.toSeq) } - test("sorting #2") { + test("partition wide sorting") { + // 2 partitions totally, and + // Partition #1 with values: + // (1, 1) + // (1, 2) + // (2, 1) + // Partition #2 with values: + // (2, 2) + // (3, 1) + // (3, 2) checkAnswer( testData2.sortBy('a.asc, 'b.asc), Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) checkAnswer( testData2.sortBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1))) checkAnswer( testData2.sortBy('a.desc, 'b.desc), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2))) checkAnswer( testData2.sortBy('a.desc, 'b.asc), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2))) } test("limit") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0378fd7e367f0..1a4232dab86e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j + case j: BroadcastLeftSemiJoinHash => j } assert(operators.size === 1) @@ -382,4 +383,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { """.stripMargin), (null, 10) :: Nil) } + + test("broadcasted left semi join operator selection") { + clearCache() + sql("CACHE TABLE testData") + val tmp = autoBroadcastJoinThreshold + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + Seq( + ("SELECT * FROM testData LEFT 
SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + sql("UNCACHE TABLE testData") + } + + test("left semi join") { + val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + checkAnswer(rdd, + (1, 1) :: + (1, 2) :: + (2, 1) :: + (2, 2) :: + (3, 1) :: + (3, 2) :: Nil) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ddf4776ecf7ae..add4e218a22ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -987,6 +987,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) } + test("oder by asc by default when not specify ascending and descending") { + checkAnswer( + sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), + Seq((3, 1), (3, 2), (2, 1), (2,2), (1, 1), (1, 2)) + ) + } + test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index bb553a0a1e50c..497897c3c0d4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -55,7 +55,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil).toSchemaRDD + TestData2(3, 2) :: Nil, 2).toSchemaRDD testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 82afa31a99a7e..1915c25392f1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -105,7 +105,9 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be test(query) { val schemaRdd = sql(query) - assertResult(expectedQueryResult.toArray, "Wrong query result") { + val queryExecution = schemaRdd.queryExecution + + assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { schemaRdd.collect().map(_.head).toArray } @@ -113,8 +115,10 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head - assert(readBatches === expectedReadBatches, "Wrong number of read batches") - assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + assert(readBatches === expectedReadBatches, s"Wrong number of read batches: $queryExecution") + assert( + readPartitions === expectedReadPartitions, + s"Wrong number of read partitions: $queryExecution") } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index b17300475b6f6..4c3a04506ce42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -28,11 +28,14 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. * - * Notice that `!(a cmp b)` are always transformed to its negated form `a cmp' b` by the - * `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)` - * results a `GtEq` filter predicate rather than a `Not`. + * NOTE: * - * @todo Add test cases for `IsNull` and `IsNotNull` after merging PR #3367 + * 1. `!(a cmp b)` is always transformed to its negated form `a cmp' b` by the + * `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)` + * results in a `GtEq` filter predicate rather than a `Not`. + * + * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred + * data type is nullable. */ class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext @@ -85,14 +88,26 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(Tuple1.apply)) { rdd => + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { + Seq(Row(true), Row(false)) + } + checkFilterPushdown(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(true) - checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]])(false) + checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]]) { + false + } } } test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(Tuple1.apply)) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { + (1 to 4).map(Row.apply(_)) + } + checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { (2 to 4).map(Row.apply(_)) @@ -118,7 +133,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toLong))) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { + (1 to 4).map(Row.apply(_)) + } + checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { (2 to 4).map(Row.apply(_)) @@ -144,7 +164,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toFloat))) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd => + checkFilterPushdown(rdd, 
'_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { + (1 to 4).map(Row.apply(_)) + } + checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { (2 to 4).map(Row.apply(_)) @@ -170,7 +195,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toDouble))) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { + (1 to 4).map(Row.apply(_)) + } + checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Double]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { (2 to 4).map(Row.apply(_)) @@ -197,6 +227,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - string") { withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { + (1 to 4).map(i => Row.apply(i.toString)) + } + checkFilterPushdown(rdd, '_1)('_1 === "1", classOf[Eq[String]])("1") checkFilterPushdown(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { (2 to 4).map(i => Row.apply(i.toString)) @@ -227,6 +262,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { rdd => + checkBinaryFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) + checkBinaryFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { + (1 to 4).map(i => Row.apply(i.b)).toSeq + } + checkBinaryFilterPushdown(rdd, '_1)('_1 === 1.b, classOf[Eq[Array[Byte]]])(1.b) checkBinaryFilterPushdown(rdd, '_1)('_1 !== 1.b, classOf[Operators.NotEq[Array[Byte]]]) { (2 to 4).map(i => Row.apply(i.b)).toSeq diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 074855389d746..a5fe2e8da2840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import scala.reflect.ClassTag + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} @@ -459,11 +461,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("make RecordFilter for simple predicates") { - def checkFilter[T <: FilterPredicate](predicate: Expression, defined: Boolean = true): Unit = { + def checkFilter[T <: FilterPredicate : ClassTag]( + predicate: Expression, + defined: Boolean = true): Unit = { val filter = ParquetFilters.createFilter(predicate) if (defined) { assert(filter.isDefined) - assert(filter.get.isInstanceOf[T]) + val tClass = implicitly[ClassTag[T]].runtimeClass + val filterGet = filter.get + assert( + tClass.isInstance(filterGet), + s"$filterGet of type ${filterGet.getClass} is not an instance of $tClass") } else { assert(filter.isEmpty) } @@ -484,7 +492,7 
@@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA checkFilter[Operators.And]('a.int === 1 && 'a.int < 4) checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4) - checkFilter[Operators.Not](!('a.int === 1)) + checkFilter[Operators.NotEq[Integer]](!('a.int === 1)) checkFilter('a.int > 'b.int, defined = false) checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 6ed8fd2768f95..7a3d76c61c3a1 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -60,7 +60,7 @@ private[hive] abstract class AbstractSparkSQLDriver( } catch { case cause: Throwable => logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(0, ExceptionUtils.getFullStackTrace(cause), null) + new CommandProcessorResponse(1, ExceptionUtils.getFullStackTrace(cause), null) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 1e44dd239458a..23283fd3fe6b1 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -101,6 +101,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "describe_comment_nonascii", "create_merge_compressed", + "create_view", "create_view_partitioned", "database_location", "database_properties", @@ -110,7 +111,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Weird DDL differences result in failures on jenkins. "create_like2", - "create_view_translate", "partitions_json", // This test is totally fine except that it includes wrong queries and expects errors, but error @@ -349,6 +349,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "create_nested_type", "create_skewed_table1", "create_struct_table", + "create_view_translate", "cross_join", "cross_product_check_1", "cross_product_check_2", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 56fe27a77b838..982e0593fcfd1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -284,7 +284,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Execute the command using Hive and return the results as a sequence. Each element * in the sequence is one row. 
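On the ParquetQuerySuite change above: with an erased type parameter, `filter.get.isInstanceOf[T]` is an unchecked test against `T`'s erased bound, so the old assertion could not distinguish concrete `FilterPredicate` subclasses. Capturing a `ClassTag` keeps the runtime class available for a real check; a minimal sketch of the pattern:

```scala
import scala.reflect.ClassTag

// A ClassTag carries the runtime class, so the check is no longer erased away.
def isA[T: ClassTag](x: Any): Boolean =
  implicitly[ClassTag[T]].runtimeClass.isInstance(x)

isA[String]("hello")  // true
isA[String](42)       // false
```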
*/ - protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = { + protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = synchronized { try { val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 9ac6915768fd1..4b8800b92a0ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -680,16 +680,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), withHaving) + Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withHaving) case (None, Some(perPartitionOrdering), None, None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withHaving) + Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving) case (None, None, Some(partitionExprs), None) => Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), + Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => - SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), + Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 93b6ef9fbc59b..7d863f9d89dae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -158,11 +158,6 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr override def foldable = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] - @transient - protected def constantReturnValue = unwrap( - returnInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue(), - returnInspector) - @transient protected lazy val deferedObjects = argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] @@ -171,7 +166,6 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr override def eval(input: Row): Any = { returnInspector // Make sure initialized. 
- if(foldable) return constantReturnValue var i = 0 while (i < children.length) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala index abed299cd957f..2a16c9d1a27c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.io.Writable * when "spark.sql.hive.convertMetastoreParquet" is set to true. */ @deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " + - "placeholder in the Hive MetaStore") + "placeholder in the Hive MetaStore", "1.2.0") class FakeParquetSerDe extends SerDe { override def getObjectInspector: ObjectInspector = new ObjectInspector { override def getCategory: Category = Category.PRIMITIVE diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java index d2d39a8c4dc28..808e2986d3b77 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java @@ -23,25 +23,21 @@ public class UDFListListInt extends UDF { /** - * * @param obj - * SQL schema: array> - * Java Type: List> - * @return + * SQL schema: array<struct<x: int, y: int, z: int>> + * Java Type: List<List<Integer>> */ + @SuppressWarnings("unchecked") public long evaluate(Object obj) { if (obj == null) { - return 0l; + return 0L; } - List listList = (List) obj; + List> listList = (List>) obj; long retVal = 0; - for (List aList : listList) { - @SuppressWarnings("unchecked") - List list = (List) aList; - @SuppressWarnings("unchecked") - Integer someInt = (Integer) list.get(1); + for (List aList : listList) { + Number someInt = (Number) aList.get(1); try { - retVal += (long) (someInt.intValue()); + retVal += someInt.longValue(); } catch (NullPointerException e) { System.out.println(e); } diff --git a/sql/hive/src/test/resources/golden/create_view_translate-0-dc7fc9ce5109ef459ee84ccfbb12d2c0 b/sql/hive/src/test/resources/golden/create_view_translate-0-dc7fc9ce5109ef459ee84ccfbb12d2c0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-1-3896ae0e680a5fdc01833533b11c07bb b/sql/hive/src/test/resources/golden/create_view_translate-1-3896ae0e680a5fdc01833533b11c07bb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-10-7016e1e3a4248564f3d08cddad7ae116 b/sql/hive/src/test/resources/golden/create_view_translate-10-7016e1e3a4248564f3d08cddad7ae116 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-11-e27c6a59a833dcbc2e5cdb7ff7972828 b/sql/hive/src/test/resources/golden/create_view_translate-11-e27c6a59a833dcbc2e5cdb7ff7972828 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-2-6b4caec6d7e3a91e61720bbd6b7697f0 b/sql/hive/src/test/resources/golden/create_view_translate-2-6b4caec6d7e3a91e61720bbd6b7697f0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-3-30dc3e80e3873af5115e4f5e39078a13 
b/sql/hive/src/test/resources/golden/create_view_translate-3-30dc3e80e3873af5115e4f5e39078a13 new file mode 100644 index 0000000000000..cec5f77033aa4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/create_view_translate-3-30dc3e80e3873af5115e4f5e39078a13 @@ -0,0 +1,27 @@ +# col_name data_type comment + +key string + +# Detailed Table Information +Database: default +Owner: animal +CreateTime: Mon Dec 29 00:57:55 PST 2014 +LastAccessTime: UNKNOWN +Protect Mode: None +Retention: 0 +Table Type: VIRTUAL_VIEW +Table Parameters: + transient_lastDdlTime 1419843475 + +# Storage Information +SerDe Library: null +InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat +OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat +Compressed: No +Num Buckets: -1 +Bucket Columns: [] +Sort Columns: [] + +# View Information +View Original Text: select cast(key as string) from src +View Expanded Text: select cast(`src`.`key` as string) from `default`.`src` diff --git a/sql/hive/src/test/resources/golden/create_view_translate-4-cefb7530126f9e60cb4a29441d578f23 b/sql/hive/src/test/resources/golden/create_view_translate-4-cefb7530126f9e60cb4a29441d578f23 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-5-856ea995681b18a543dc0e53b8b43a8e b/sql/hive/src/test/resources/golden/create_view_translate-5-856ea995681b18a543dc0e53b8b43a8e new file mode 100644 index 0000000000000..bf582fc0964a3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/create_view_translate-5-856ea995681b18a543dc0e53b8b43a8e @@ -0,0 +1,32 @@ +# col_name data_type comment + +key int +value string + +# Detailed Table Information +Database: default +Owner: animal +CreateTime: Mon Dec 29 00:57:55 PST 2014 +LastAccessTime: UNKNOWN +Protect Mode: None +Retention: 0 +Table Type: VIRTUAL_VIEW +Table Parameters: + transient_lastDdlTime 1419843475 + +# Storage Information +SerDe Library: null +InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat +OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat +Compressed: No +Num Buckets: -1 +Bucket Columns: [] +Sort Columns: [] + +# View Information +View Original Text: select key, value from ( + select key, value from src +) a +View Expanded Text: select key, value from ( + select `src`.`key`, `src`.`value` from `default`.`src` +) `a` diff --git a/sql/hive/src/test/resources/golden/create_view_translate-6-a14cfe3eff322066e61023ec06c7735d b/sql/hive/src/test/resources/golden/create_view_translate-6-a14cfe3eff322066e61023ec06c7735d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-7-e947bf2dacc907825df154a4131a3fcc b/sql/hive/src/test/resources/golden/create_view_translate-7-e947bf2dacc907825df154a4131a3fcc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-8-b1a99b0beffb0b298aec9233ecc0707f b/sql/hive/src/test/resources/golden/create_view_translate-8-b1a99b0beffb0b298aec9233ecc0707f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/create_view_translate-9-fc0dc39c4796d917685e0797bc4a9786 b/sql/hive/src/test/resources/golden/create_view_translate-9-fc0dc39c4796d917685e0797bc4a9786 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index bfe608a51a30b..f90d3607915ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.sql.Date import java.util +import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.serde2.io.DoubleWritable import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory @@ -63,6 +64,11 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { .get()) } + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + val data = Literal(true) :: Literal(0.asInstanceOf[Byte]) :: @@ -121,11 +127,11 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { row1.zip(row2).map { - case (r1, r2) => checkValues(r1, r2) + case (r1, r2) => checkValue(r1, r2) } } - def checkValues(v1: Any, v2: Any): Unit = { + def checkValue(v1: Any, v2: Any): Unit = { (v1, v2) match { case (r1: Decimal, r2: Decimal) => // Ignore the Decimal precision @@ -195,26 +201,26 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { }) checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) - checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) val d = row(0) :: row(0) :: Nil - checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) } test("wrap / unwrap Map Type") { val dt = MapType(dataTypes(0), dataTypes(1)) val d = Map(row(0) -> row(1)) - checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ff4071d8e2f10..4b6a9308b9811 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ @@ -193,4 +193,52 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ) } + test("auto converts to broadcast left semi join, by size estimate of a relation") { + val leftSemiJoinQuery = + """SELECT * FROM src a + |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin + val answer = (86, "val_86") :: Nil + + var rdd = sql(leftSemiJoinQuery) + + // Assert src has a size smaller than the threshold. + val sizes = rdd.queryExecution.analyzed.collect { + case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass + .isAssignableFrom(r.getClass) => + r.statistics.sizeInBytes + } + assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold + && sizes(0) <= autoBroadcastJoinThreshold, + s"query should contain two relations, each of which has size smaller than autoConvertSize") + + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. + var bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.size === 1, + s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + + checkAnswer(rdd, answer) // check correctness of output + + TestHive.settings.synchronized { + val tmp = autoBroadcastJoinThreshold + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + rdd = sql(leftSemiJoinQuery) + bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") + + val shj = rdd.queryExecution.sparkPlan.collect { + case j: LeftSemiJoinHash => j + } + assert(shj.size === 1, + "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp") + } + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 8011f9b8773b3..4104df8f8e022 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -132,7 +132,7 @@ abstract class HiveComparisonTest def isSorted(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false - case PhysicalOperation(_, _, Sort(_, _)) => true + case PhysicalOperation(_, _, Sort(_, true, _)) => true case _ => plan.children.iterator.exists(isSorted) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4d81acc753a27..fb6da33e88ef6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -56,6 +56,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { 
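The StatisticsSuite test above drives broadcast-join selection through `SET spark.sql.autoBroadcastJoinThreshold`. A hedged usage sketch of the same toggle outside the test harness (assumes a `sqlContext` in scope; `big` and `small` are hypothetical registered tables, and the saved default value is an assumption):

```scala
// Save the current threshold, disable automatic broadcast joins, then restore.
val previous = sqlContext.getConf("spark.sql.autoBroadcastJoinThreshold", "10485760")
sqlContext.sql("SET spark.sql.autoBroadcastJoinThreshold=-1")  // -1 disables broadcasting
try {
  sqlContext.sql("SELECT * FROM big LEFT SEMI JOIN small ON big.key = small.key").collect()
} finally {
  sqlContext.sql(s"SET spark.sql.autoBroadcastJoinThreshold=$previous")
}
```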
Locale.setDefault(originalLocale) } + test("SPARK-4908: concurrent hive native commands") { + (1 to 100).par.map { _ => + sql("USE default") + sql("SHOW TABLES") + } + } + createQueryTest("constant object inspector for generic udf", """SELECT named_struct( lower("AA"), "10", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index a0ace91060a28..16f77a438e1ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.Row import org.apache.spark.util.Utils @@ -76,4 +77,15 @@ class HiveTableScanSuite extends HiveComparisonTest { === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } + + test("SPARK-4959 Attributes are case sensitive when using a select query from a projection") { + sql("create table spark_4959 (col1 string)") + sql("""insert into table spark_4959 select "hi" from src limit 1""") + table("spark_4959").select( + 'col1.as('CaseSensitiveColName), + 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2") + + assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f57f31af15566..5d0fb7237011f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -32,6 +32,13 @@ case class Nested3(f3: Int) * valid, but Hive currently cannot execute it.
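Editor's note on the SPARK-4908 test above: it hammers the Hive client from many threads at once by mapping over a parallel collection. A tiny standalone sketch of the `.par` idiom follows; it is illustrative only. In the Scala versions this patch targets (2.10/2.11) `.par` is built in, while Scala 2.13+ needs the separate scala-parallel-collections module and an extra import.

object ParSketch {
  def main(args: Array[String]): Unit = {
    // Each element is processed on the default fork-join pool, so the closure
    // must only call thread-safe code.
    val results = (1 to 100).par.map { i => i * i }
    println(results.sum)
  }
}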
*/ class SQLQuerySuite extends QueryTest { + test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { + checkAnswer( + sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), + sql("SELECT key + key as a FROM src ORDER BY a").collect().toSeq + ) + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect sql( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala index a0aeacbc733bd..fdbbe2aa6ef08 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala @@ -17,30 +17,63 @@ package org.apache.spark.streaming +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + private[streaming] class ContextWaiter { + + private val lock = new ReentrantLock() + private val condition = lock.newCondition() + + // Guarded by "lock" private var error: Throwable = null - private var stopped: Boolean = false - def notifyError(e: Throwable) = synchronized { - error = e - notifyAll() - } + // Guarded by "lock" + private var stopped: Boolean = false - def notifyStop() = synchronized { - stopped = true - notifyAll() + def notifyError(e: Throwable): Unit = { + lock.lock() + try { + error = e + condition.signalAll() + } finally { + lock.unlock() + } } - def waitForStopOrError(timeout: Long = -1) = synchronized { - // If already had error, then throw it - if (error != null) { - throw error + def notifyStop(): Unit = { + lock.lock() + try { + stopped = true + condition.signalAll() + } finally { + lock.unlock() } + } - // If not already stopped, then wait - if (!stopped) { - if (timeout < 0) wait() else wait(timeout) + /** + * Return `true` if it's stopped; or throw the reported error if `notifyError` has been called; or + * `false` if the waiting time detectably elapsed before return from the method. + */ + def waitForStopOrError(timeout: Long = -1): Boolean = { + lock.lock() + try { + if (timeout < 0) { + while (!stopped && error == null) { + condition.await() + } + } else { + var nanos = TimeUnit.MILLISECONDS.toNanos(timeout) + while (!stopped && error == null && nanos > 0) { + nanos = condition.awaitNanos(nanos) + } + } + // If already had error, then throw it if (error != null) throw error + // already stopped or timeout + stopped + } finally { + lock.unlock() } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ecab5510a8e7b..8ef0787137845 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.mutable.Queue -import scala.language.implicitConversions import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} @@ -523,9 +522,11 @@ object StreamingContext extends Logging { private[streaming] val DEFAULT_CLEANER_TTL = 3600 - implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) + @deprecated("Replaced by implicit functions in the DStream companion object. 
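Editor's note on the ContextWaiter rewrite above: it replaces wait/notifyAll with an explicit ReentrantLock plus Condition, looping on awaitNanos so spurious wakeups are harmless and the remaining time shrinks on every wakeup. Below is a minimal, self-contained sketch of the same pattern with hypothetical names; it is not the class from the patch.

import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock

class SimpleStopLatch {
  private val lock = new ReentrantLock()
  private val condition = lock.newCondition()
  private var stopped = false // guarded by "lock"

  def stop(): Unit = {
    lock.lock()
    try {
      stopped = true
      condition.signalAll()
    } finally lock.unlock()
  }

  /** Returns true if stopped, false if the timeout elapsed first. */
  def await(timeoutMs: Long): Boolean = {
    lock.lock()
    try {
      var nanos = TimeUnit.MILLISECONDS.toNanos(timeoutMs)
      // Loop: awaitNanos may return early (spurious wakeup), so re-check the
      // predicate and keep waiting with the remaining time.
      while (!stopped && nanos > 0) {
        nanos = condition.awaitNanos(nanos)
      }
      stopped
    } finally lock.unlock()
  }
}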
This is " + + "kept here only for backward compatibility.", "1.3.0") + def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { - new PairDStreamFunctions[K, V](stream) + DStream.toPairDStreamFunctions(stream)(kt, vt, ord) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index e35a568ddf115..9697437dd2fe5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -29,9 +29,17 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { private val streamingListener = ssc.progressListener private def registerGauge[T](name: String, f: StreamingJobProgressListener => T, - defaultValue: T) { + defaultValue: T): Unit = { + registerGaugeWithOption[T](name, + (l: StreamingJobProgressListener) => Option(f(streamingListener)), defaultValue) + } + + private def registerGaugeWithOption[T]( + name: String, + f: StreamingJobProgressListener => Option[T], + defaultValue: T): Unit = { metricRegistry.register(MetricRegistry.name("streaming", name), new Gauge[T] { - override def getValue: T = Option(f(streamingListener)).getOrElse(defaultValue) + override def getValue: T = f(streamingListener).getOrElse(defaultValue) }) } @@ -41,6 +49,12 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { // Gauge for number of total completed batches registerGauge("totalCompletedBatches", _.numTotalCompletedBatches, 0L) + // Gauge for number of total received records + registerGauge("totalReceivedRecords", _.numTotalReceivedRecords, 0L) + + // Gauge for number of total processed records + registerGauge("totalProcessedRecords", _.numTotalProcessedRecords, 0L) + // Gauge for number of unprocessed batches registerGauge("unprocessedBatches", _.numUnprocessedBatches, 0L) @@ -55,19 +69,30 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { // Gauge for last completed batch, useful for monitoring the streaming job's running status, // displayed data -1 for any abnormal condition. - registerGauge("lastCompletedBatch_submissionTime", - _.lastCompletedBatch.map(_.submissionTime).getOrElse(-1L), -1L) - registerGauge("lastCompletedBatch_processStartTime", - _.lastCompletedBatch.flatMap(_.processingStartTime).getOrElse(-1L), -1L) - registerGauge("lastCompletedBatch_processEndTime", - _.lastCompletedBatch.flatMap(_.processingEndTime).getOrElse(-1L), -1L) + registerGaugeWithOption("lastCompletedBatch_submissionTime", + _.lastCompletedBatch.map(_.submissionTime), -1L) + registerGaugeWithOption("lastCompletedBatch_processingStartTime", + _.lastCompletedBatch.flatMap(_.processingStartTime), -1L) + registerGaugeWithOption("lastCompletedBatch_processingEndTime", + _.lastCompletedBatch.flatMap(_.processingEndTime), -1L) + + // Gauge for last completed batch's delay information. 
+ registerGaugeWithOption("lastCompletedBatch_processingDelay", + _.lastCompletedBatch.flatMap(_.processingDelay), -1L) + registerGaugeWithOption("lastCompletedBatch_schedulingDelay", + _.lastCompletedBatch.flatMap(_.schedulingDelay), -1L) + registerGaugeWithOption("lastCompletedBatch_totalDelay", + _.lastCompletedBatch.flatMap(_.totalDelay), -1L) // Gauge for last received batch, useful for monitoring the streaming job's running status, // displayed data -1 for any abnormal condition. - registerGauge("lastReceivedBatch_submissionTime", - _.lastCompletedBatch.map(_.submissionTime).getOrElse(-1L), -1L) - registerGauge("lastReceivedBatch_processStartTime", - _.lastCompletedBatch.flatMap(_.processingStartTime).getOrElse(-1L), -1L) - registerGauge("lastReceivedBatch_processEndTime", - _.lastCompletedBatch.flatMap(_.processingEndTime).getOrElse(-1L), -1L) + registerGaugeWithOption("lastReceivedBatch_submissionTime", + _.lastCompletedBatch.map(_.submissionTime), -1L) + registerGaugeWithOption("lastReceivedBatch_processingStartTime", + _.lastCompletedBatch.flatMap(_.processingStartTime), -1L) + registerGaugeWithOption("lastReceivedBatch_processingEndTime", + _.lastCompletedBatch.flatMap(_.processingEndTime), -1L) + + // Gauge for last received batch records. + registerGauge("lastReceivedBatch_records", _.lastReceivedBatchRecords.values.sum, 0L) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index bb44b906d7386..de124cf40eff1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -36,7 +36,6 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream /** @@ -815,6 +814,6 @@ object JavaPairDStream { def scalaToJavaLong[K: ClassTag](dstream: JavaPairDStream[K, Long]) : JavaPairDStream[K, JLong] = { - StreamingContext.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) + DStream.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index dbf1ebbaf653a..7f8651e719d84 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -20,8 +20,8 @@ package org.apache.spark.streaming.dstream import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import scala.deprecated import scala.collection.mutable.HashMap +import scala.language.implicitConversions import scala.reflect.ClassTag import scala.util.matching.Regex @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming.scheduler.Job import org.apache.spark.util.{CallSite, MetadataCleaner, Utils} @@ -48,8 +48,7 @@ import org.apache.spark.util.{CallSite, MetadataCleaner, 
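Editor's note on the StreamingSource changes above: the new registerGaugeWithOption wraps a possibly-missing value in an Option and falls back to a default, instead of relying on Option(...) around a nullable result. A hedged, self-contained sketch of that gauge pattern follows; the object, field and metric names are hypothetical stand-ins, not the patch's code.

import com.codahale.metrics.{Gauge, MetricRegistry}

object GaugeSketch {
  val metricRegistry = new MetricRegistry()

  // A stand-in for a listener field that may not be populated yet.
  @volatile var latestDelayMs: Option[Long] = None

  def registerGaugeWithOption[T](name: String, f: () => Option[T], default: T): Unit = {
    metricRegistry.register(MetricRegistry.name("streaming", name), new Gauge[T] {
      // Report the default (e.g. -1) whenever the value is absent.
      override def getValue: T = f().getOrElse(default)
    })
  }

  def main(args: Array[String]): Unit = {
    registerGaugeWithOption[Long]("lastCompletedBatch_processingDelay", () => latestDelayMs, -1L)
  }
}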
Utils} * `window`. In addition, [[org.apache.spark.streaming.dstream.PairDStreamFunctions]] contains * operations available only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and * `join`. These operations are automatically available on any DStream of pairs - * (e.g., DStream[(Int, Int)] through implicit conversions when - * `org.apache.spark.streaming.StreamingContext._` is imported. + * (e.g., DStream[(Int, Int)]) through implicit conversions. * * DStreams internally is characterized by a few basic properties: - A list of other DStreams that the DStream depends on @@ -802,10 +801,21 @@ abstract class DStream[T: ClassTag] ( } } -private[streaming] object DStream { +object DStream { + + // `toPairDStreamFunctions` was in StreamingContext before 1.3 and users had to + // `import StreamingContext._` to enable it. Now we move it here to make the compiler find + // it automatically. However, we still keep the old function in StreamingContext for backward + // compatibility and forward to the following function directly. + + implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): + PairDStreamFunctions[K, V] = { + new PairDStreamFunctions[K, V](stream) + } /** Get the creation site of a DStream from the stack trace of when the DStream is created. */ - def getCreationSite(): CallSite = { + private[streaming] def getCreationSite(): CallSite = { val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 98539e06b4e29..8a58571632447 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -27,12 +27,10 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.StreamingContext.rddToFileName /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. - * Import `org.apache.spark.streaming.StreamingContext._` at the top of your program to use - * these functions.
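Editor's note on moving toPairDStreamFunctions into the DStream companion object above: implicits declared in a type's companion object are part of its implicit scope, so callers no longer need the `import StreamingContext._`. The following toy sketch demonstrates why that works; MyStream and PairOps are hypothetical classes invented for illustration, not Spark types.

import scala.language.implicitConversions

object ImplicitScopeSketch {
  class MyStream[T](val items: Seq[T])

  class PairOps[K, V](self: MyStream[(K, V)]) {
    def keys: Seq[K] = self.items.map(_._1)
  }

  object MyStream {
    // Lives in the companion object, so the compiler finds it without an import.
    implicit def toPairOps[K, V](s: MyStream[(K, V)]): PairOps[K, V] = new PairOps(s)
  }

  def main(args: Array[String]): Unit = {
    val s = new MyStream(Seq(1 -> "a", 2 -> "b"))
    println(s.keys) // compiles with no explicit import of MyStream.toPairOps
  }
}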
*/ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 1a47089e513c4..c0a5af0b65cc3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -17,8 +17,6 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.StreamingContext._ - import org.apache.spark.rdd.RDD import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD} import org.apache.spark.Partitioner diff --git a/streaming/src/main/scala/org/apache/spark/streaming/package.scala b/streaming/src/main/scala/org/apache/spark/streaming/package.scala index 4dd985cf5a178..2153ae0d34184 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/package.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/package.scala @@ -26,7 +26,7 @@ package org.apache.spark * available only on DStreams * of key-value pairs, such as `groupByKey` and `reduceByKey`. These operations are automatically * available on any DStream of the right type (e.g. DStream[(Int, Int)] through implicit - * conversions when you `import org.apache.spark.streaming.StreamingContext._`. + * conversions. * * For the Java API of Spark Streaming, take a look at the * [[org.apache.spark.streaming.api.java.JavaStreamingContext]] which serves as the entry point, and diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 55765dc90698b..79263a7183977 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -116,7 +116,7 @@ private[streaming] class BlockGenerator( /** * Push a single data item into the buffer. After buffering the data, the - * `BlockGeneratorListnere.onAddData` callback will be called. All received data items + * `BlockGeneratorListener.onAddData` callback will be called. All received data items * will be periodically pushed into BlockManager. 
*/ def addDataWithCallback(data: Any, metadata: Any) = synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index c0670e22a7aee..8b97db8dd36f1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -187,10 +187,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( } // Combine the futures, wait for both to complete, and return the write ahead log segment - val combinedFuture = for { - _ <- storeInBlockManagerFuture - fileSegment <- storeInWriteAheadLogFuture - } yield fileSegment + val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) val segment = Await.result(combinedFuture, blockStoreTimeout) WriteAheadLogBasedStoreResult(blockId, segment) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index f61069b56db5e..5ee53a5c5f561 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -25,7 +25,6 @@ import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted import org.apache.spark.streaming.scheduler.BatchInfo import org.apache.spark.streaming.scheduler.StreamingListenerBatchSubmitted import org.apache.spark.util.Distribution -import org.apache.spark.Logging private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) @@ -36,6 +35,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) private val completedaBatchInfos = new Queue[BatchInfo] private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) private var totalCompletedBatches = 0L + private var totalReceivedRecords = 0L + private var totalProcessedRecords = 0L private val receiverInfos = new HashMap[Int, ReceiverInfo] val batchDuration = ssc.graph.batchDuration.milliseconds @@ -65,6 +66,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized { runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) + + batchStarted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => + totalReceivedRecords += infos.map(_.numRecords).sum + } } override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized { @@ -73,6 +78,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) completedaBatchInfos.enqueue(batchCompleted.batchInfo) if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() totalCompletedBatches += 1L + + batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => + totalProcessedRecords += infos.map(_.numRecords).sum + } } def numReceivers = synchronized { @@ -83,6 +92,14 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) totalCompletedBatches } + def numTotalReceivedRecords: Long = synchronized { + totalReceivedRecords + } + + def numTotalProcessedRecords: Long = synchronized { + totalProcessedRecords + } + def 
numUnprocessedBatches: Long = synchronized { waitingBatchInfos.size + runningBatchInfos.size } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index ce645fccba1d0..12cc0de7509d6 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -57,7 +57,7 @@ public void equalIterable(Iterable a, Iterable b) { @Test public void testInitialization() { - Assert.assertNotNull(ssc.sc()); + Assert.assertNotNull(ssc.sparkContext()); } @SuppressWarnings("unchecked") @@ -662,7 +662,7 @@ public void testStreamingContextTransform(){ listOfDStreams1, new Function2>, Time, JavaRDD>() { public JavaRDD call(List> listOfRDDs, Time time) { - assert(listOfRDDs.size() == 2); + Assert.assertEquals(2, listOfRDDs.size()); return null; } } @@ -675,7 +675,7 @@ public JavaRDD call(List> listOfRDDs, Time time) { listOfDStreams2, new Function2>, Time, JavaPairRDD>>() { public JavaPairRDD> call(List> listOfRDDs, Time time) { - assert(listOfRDDs.size() == 3); + Assert.assertEquals(3, listOfRDDs.size()); JavaRDD rdd1 = (JavaRDD)listOfRDDs.get(0); JavaRDD rdd2 = (JavaRDD)listOfRDDs.get(1); JavaRDD> rdd3 = (JavaRDD>)listOfRDDs.get(2); @@ -969,7 +969,7 @@ public Integer call(Tuple2 in) throws Exception { }); JavaTestUtils.attachTestOutputStream(reversed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -1012,7 +1012,7 @@ public Iterable> call(Tuple2 in) throws } }); JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -1163,9 +1163,9 @@ public void testGroupByKeyAndWindow() { JavaTestUtils.attachTestOutputStream(groupWindowed); List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); - assert(result.size() == expected.size()); + Assert.assertEquals(expected.size(), result.size()); for (int i = 0; i < result.size(); i++) { - assert(convert(result.get(i)).equals(convert(expected.get(i)))); + Assert.assertEquals(convert(expected.get(i)), convert(result.get(i))); } } @@ -1383,7 +1383,7 @@ public JavaPairRDD call(JavaPairRDD in) thro }); JavaTestUtils.attachTestOutputStream(sorted); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071b..1e24da7f5f60c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + 
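Editor's note on the StreamingJobProgressListener changes above: the new running totals are updated from listener callbacks and read from the metrics source, so both sides synchronize on the same instance. A minimal sketch of that counter pattern follows; the class is hypothetical and much smaller than the real listener.

class RecordTotals {
  private var totalReceived = 0L
  private var totalProcessed = 0L

  // Called from streaming listener callbacks.
  def onBatchStarted(numRecords: Long): Unit = synchronized { totalReceived += numRecords }
  def onBatchCompleted(numRecords: Long): Unit = synchronized { totalProcessed += numRecords }

  // Read from the metrics/UI side; synchronized so readers see consistent values.
  def numTotalReceivedRecords: Long = synchronized { totalReceived }
  def numTotalProcessedRecords: Long = synchronized { totalProcessed }
}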
.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 86b96785d7b87..199f5e7161124 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} import org.apache.spark.HashPartitioner diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index c97998add8ffa..72d055eb2ea31 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.util.ManualClock import org.apache.spark.util.Utils diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 5dbb7232009eb..e0f14fd954280 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming import org.apache.spark.Logging import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils -import org.apache.spark.streaming.StreamingContext._ import scala.util.Random import scala.collection.mutable.ArrayBuffer diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index 471c99fab4682..a5d2bb2fde16c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.storage.StorageLevel diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala new file mode 100644 index 0000000000000..d0bf328f2b74d --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streamingtest + +/** + * A test suite to make sure all `implicit` functions work correctly. + * + * As `implicit` is a compiler feature, we don't need to run this class. + * All we need to do is make the compiler happy. + */ +class ImplicitSuite { + + // We only want to test if `implicit` works well with the compiler, so we don't need a real DStream. + def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null + + def testToPairDStreamFunctions(): Unit = { + val dstream: org.apache.spark.streaming.dstream.DStream[(Int, Int)] = mockDStream + dstream.groupByKey() + } +} diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index db58eb642b56d..15ee95070a3d3 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util.Utils @@ -49,13 +49,13 @@ object StoragePerfTester { val writeData = "1" * recordLength val executor = Executors.newFixedThreadPool(numMaps) - System.setProperty("spark.shuffle.compress", "false") - System.setProperty("spark.shuffle.sync", "true") - System.setProperty("spark.shuffle.manager", - "org.apache.spark.shuffle.hash.HashShuffleManager") + val conf = new SparkConf() + .set("spark.shuffle.compress", "false") + .set("spark.shuffle.sync", "true") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") // This is only used to instantiate a BlockManager. All thread scheduling is done manually.
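Editor's note on the ImplicitSuite added above: the suite exists only to be compiled, since successfully compiling a call like groupByKey proves the implicit conversion is in scope. A hedged sketch of the same compile-time-only idea with hypothetical names follows.

object CompileOnlyImplicitCheck {
  // Returning null is fine: the method body is never executed, only type-checked.
  def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null

  def checkToPairDStreamFunctions(): Unit = {
    val stream = mockDStream[(String, Int)]
    // Compiles only if the DStream -> PairDStreamFunctions conversion is found.
    stream.reduceByKey(_ + _)
  }
}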
- val sc = new SparkContext("local[4]", "Write Tester") + val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] def writeOutputBytes(mapId: Int, total: AtomicLong) = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2e45435c4abb..9c77dff48dc8b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -311,7 +311,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def cleanupStagingDir(fs: FileSystem) { var stagingDirPath: Path = null try { - val preserveFiles = sparkConf.get("spark.yarn.preserve.staging.files", "false").toBoolean + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) if (!preserveFiles) { stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) if (stagingDirPath == null) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 7305249f80e83..39f1021c9d942 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -39,6 +39,8 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var appName: String = "Spark" var priority = 0 + parseArgs(args.toList) + // Additional memory to allocate to containers // For now, use driver's memory overhead as our AM container's memory overhead val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", @@ -50,7 +52,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private val isDynamicAllocationEnabled = sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) - parseArgs(args.toList) loadEnvironmentArgs() validateArgs() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 5f0c67f05c9dd..eb97a7b3c59a4 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -693,7 +693,7 @@ private[spark] object ClientBase extends Logging { addClasspathEntry(Environment.PWD.$(), env) // Normally the users app.jar is last in case conflicts with spark jars - if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { + if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { addUserClasspath(args, sparkConf, env) addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) populateHadoopClasspath(conf, env)
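Editor's note on the YARN configuration changes above: sparkConf.getBoolean parses the value and supplies a default in one call, replacing the get(key, "false").toBoolean pattern. A short standalone sketch follows; the config key is the one shown in the hunk, the surrounding object is illustrative only.

import org.apache.spark.SparkConf

object GetBooleanSketch {
  def main(args: Array[String]): Unit = {
    // loadDefaults = false keeps the sketch independent of system properties.
    val sparkConf = new SparkConf(loadDefaults = false)
      .set("spark.yarn.preserve.staging.files", "true")

    // Returns the parsed boolean, or false if the key is unset.
    val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false)
    println(s"preserve staging files: $preserveFiles")
  }
}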