From 797f8a000773d848fa52c7fe2eb1b5e5e7f6c55a Mon Sep 17 00:00:00 2001 From: Pierre Borckmans Date: Thu, 19 Mar 2015 08:02:06 -0400 Subject: [PATCH 01/47] [SPARK-6402][DOC] - Remove some refererences to shark in docs and ec2 EC2 script and job scheduling documentation still refered to Shark. I removed these references. I also removed a remaining `SHARK_VERSION` variable from `ec2-variables.sh`. Author: Pierre Borckmans Closes #5083 from pierre-borckmans/remove_refererences_to_shark_in_docs and squashes the following commits: 4e90ffc [Pierre Borckmans] Removed deprecated SHARK_VERSION caea407 [Pierre Borckmans] Remove shark reference from ec2 script doc 196c744 [Pierre Borckmans] Removed references to Shark --- docs/ec2-scripts.md | 2 +- docs/job-scheduling.md | 6 ++---- ec2/deploy.generic/root/spark-ec2/ec2-variables.sh | 1 - 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 8c9a1e1262d8f..7f60f82b966fe 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -5,7 +5,7 @@ title: Running Spark on EC2 The `spark-ec2` script, located in Spark's `ec2` directory, allows you to launch, manage and shut down Spark clusters on Amazon EC2. It automatically -sets up Spark, Shark and HDFS on the cluster for you. This guide describes +sets up Spark and HDFS on the cluster for you. This guide describes how to use `spark-ec2` to launch clusters, how to run jobs on them, and how to shut them down. It assumes you've already signed up for an EC2 account on the [Amazon Web Services site](http://aws.amazon.com/). diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 5295e351dd711..963e88a3e1d8f 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -14,8 +14,7 @@ runs an independent set of executor processes. The cluster managers that Spark r facilities for [scheduling across applications](#scheduling-across-applications). Second, _within_ each Spark application, multiple "jobs" (Spark actions) may be running concurrently if they were submitted by different threads. This is common if your application is serving requests -over the network; for example, the [Shark](http://shark.cs.berkeley.edu) server works this way. Spark -includes a [fair scheduler](#scheduling-within-an-application) to schedule resources within each SparkContext. +over the network. Spark includes a [fair scheduler](#scheduling-within-an-application) to schedule resources within each SparkContext. # Scheduling Across Applications @@ -52,8 +51,7 @@ an application to gain back cores on one node when it has work to do. To use thi Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying -the same RDDs. For example, the [Shark](http://shark.cs.berkeley.edu) JDBC server works this way for SQL -queries. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will +the same RDDs. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will provide another approach to share RDDs. ## Dynamic Resource Allocation diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh index 0857657152ec7..4f3e8da809f7f 100644 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh @@ -25,7 +25,6 @@ export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" export MODULES="{{modules}}" export SPARK_VERSION="{{spark_version}}" -export SHARK_VERSION="{{shark_version}}" export TACHYON_VERSION="{{tachyon_version}}" export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" export SWAP_MB="{{swap}}" From 3c4e486b9c8d3f256e801db7c176ab650c976135 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 19 Mar 2015 08:51:49 -0400 Subject: [PATCH 02/47] [SPARK-5843] [API] Allowing map-side combine to be specified in Java. Specifically, when calling JavaPairRDD.combineByKey(), there is a new six-parameter method that exposes the map-side-combine boolean as the fifth parameter and the serializer as the sixth parameter. Author: mcheah Closes #4634 from mccheah/pair-rdd-map-side-combine and squashes the following commits: 5c58319 [mcheah] Fixing compiler errors. 3ce7deb [mcheah] Addressing style and documentation comments. 7455c7a [mcheah] Allowing Java combineByKey to specify Serializer as well. 6ddd729 [mcheah] [SPARK-5843] Allowing map-side combine to be specified in Java. --- .../apache/spark/api/java/JavaPairRDD.scala | 46 ++++++++++++---- .../java/org/apache/spark/JavaAPISuite.java | 53 +++++++++++++++++-- 2 files changed, 87 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 4eadc9a85613e..a023712be1166 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} import org.apache.spark.rdd.RDD.rddToPairRDDFunctions +import org.apache.spark.serializer.Serializer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -227,24 +228,51 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) * - `mergeCombiners`, to combine two C's into a single one. * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). + * In addition, users can control the partitioning of the output RDD, the serializer that is use + * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple + * items with the same key). */ def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner): JavaPairRDD[K, C] = { - implicit val ctag: ClassTag[C] = fakeClassTag + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner, + mapSideCombine: Boolean, + serializer: Serializer): JavaPairRDD[K, C] = { + implicit val ctag: ClassTag[C] = fakeClassTag fromRDD(rdd.combineByKey( createCombiner, mergeValue, mergeCombiners, - partitioner + partitioner, + mapSideCombine, + serializer )) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a + * "combined type" C * Note that V and C can be different -- for example, one might group an + * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three + * functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD. This method automatically + * uses map-side aggregation in shuffling the RDD. + */ + def combineByKey[C](createCombiner: JFunction[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner): JavaPairRDD[K, C] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, true, null) + } + + /** + * Simplified version of combineByKey that hash-partitions the output RDD and uses map-side + * aggregation. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -488,7 +516,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing - * partitioner/parallelism level. + * partitioner/parallelism level and using map-side aggregation. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 8ec54360ca42a..d4b5bb519157c 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -24,11 +24,12 @@ import java.util.*; import java.util.concurrent.*; -import org.apache.spark.input.PortableDataStream; +import scala.collection.JavaConversions; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; @@ -51,8 +52,11 @@ import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; +import org.apache.spark.rdd.RDD; +import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.StatCounter; @@ -726,8 +730,8 @@ public void javaDoubleRDDHistoGram() { Tuple2 results = rdd.histogram(2); double[] expected_buckets = {1.0, 2.5, 4.0}; long[] expected_counts = {2, 2}; - Assert.assertArrayEquals(expected_buckets, results._1, 0.1); - Assert.assertArrayEquals(expected_counts, results._2); + Assert.assertArrayEquals(expected_buckets, results._1(), 0.1); + Assert.assertArrayEquals(expected_counts, results._2()); // Test with provided buckets long[] histogram = rdd.histogram(expected_buckets); Assert.assertArrayEquals(expected_counts, histogram); @@ -1424,6 +1428,49 @@ public void checkpointAndRestore() { Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); } + @Test + public void combineByKey() { + JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); + Function keyFunction = new Function() { + @Override + public Integer call(Integer v1) throws Exception { + return v1 % 3; + } + }; + Function createCombinerFunction = new Function() { + @Override + public Integer call(Integer v1) throws Exception { + return v1; + } + }; + + Function2 mergeValueFunction = new Function2() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 + v2; + } + }; + + JavaPairRDD combinedRDD = originalRDD.keyBy(keyFunction) + .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); + Map results = combinedRDD.collectAsMap(); + ImmutableMap expected = ImmutableMap.of(0, 9, 1, 5, 2, 7); + Assert.assertEquals(expected, results); + + Partitioner defaultPartitioner = Partitioner.defaultPartitioner( + combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.>newArrayList())); + combinedRDD = originalRDD.keyBy(keyFunction) + .combineByKey( + createCombinerFunction, + mergeValueFunction, + mergeValueFunction, + defaultPartitioner, + false, + new KryoSerializer(new SparkConf())); + results = combinedRDD.collectAsMap(); + Assert.assertEquals(expected, results); + } + @SuppressWarnings("unchecked") @Test public void mapOnPairRDD() { From dda4dedca0459fc7c00eb1d9cb07e14af1621e0f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Mar 2015 11:10:20 -0400 Subject: [PATCH 03/47] [SPARK-6291] [MLLIB] GLM toString & toDebugString GLM toString prints out intercept, numFeatures. For LogisticRegression and SVM model, toString also prints out numClasses, threshold. GLM toDebugString prints out the whole weights, intercept. Author: Yanbo Liang Closes #5038 from yanboliang/spark-6291 and squashes the following commits: 2f578b0 [Yanbo Liang] code format 78b33f2 [Yanbo Liang] fix typos 1e8a023 [Yanbo Liang] GLM toString & toDebugString --- .../spark/mllib/classification/LogisticRegression.scala | 4 ++++ .../scala/org/apache/spark/mllib/classification/SVM.scala | 4 ++++ .../mllib/regression/GeneralizedLinearAlgorithm.scala | 7 ++++++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index b787667b018e6..e7c3599ff619c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -163,6 +163,10 @@ class LogisticRegressionModel ( } override protected def formatVersion: String = "1.0" + + override def toString: String = { + s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.get}" + } } object LogisticRegressionModel extends Loader[LogisticRegressionModel] { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index cfc7f868a02f0..52fb62dcff1b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -86,6 +86,10 @@ class SVMModel ( } override protected def formatVersion: String = "1.0" + + override def toString: String = { + s"${super.toString}, numClasses = 2, threshold = ${threshold.get}" + } } object SVMModel extends Loader[SVMModel] { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index b262bec904525..45b9ebb4cc0d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -76,7 +76,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double predictPoint(testData, weights, intercept) } - override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept) + /** + * Print a summary of the model. + */ + override def toString: String = { + s"${this.getClass.getName}: intercept = ${intercept}, numFeatures = ${weights.size}" + } } /** From 8cb23a1f9a3ed08e57865bcb6cc1cc7902881073 Mon Sep 17 00:00:00 2001 From: Brennon York Date: Thu, 19 Mar 2015 11:18:24 -0400 Subject: [PATCH 04/47] [SPARK-5313][Project Infra]: Create simple framework for highlighting changes introduced in a PR Built a simple framework with a `dev/tests` directory to house all pull request related tests. I've moved the two original tests (`pr_merge_ability` and `pr_public_classes`) into the new `dev/tests` directory and tested to the best of my ability. At this point I need to test against Jenkins actually running the new `run-tests-jenkins` script to ensure things aren't broken down the path. Author: Brennon York Closes #5072 from brennonyork/SPARK-5313 and squashes the following commits: 8ae990c [Brennon York] added dev/run-tests back, removed echo 5db4ed4 [Brennon York] removed the git checkout 1b50050 [Brennon York] adding echos to see what jenkins is seeing b823959 [Brennon York] removed run-tests to further test the public_classes pr test 2b9ce12 [Brennon York] added the dev/run-tests call back in ffd49c0 [Brennon York] remove -c from bash as that was removing the trailing args 735d615 [Brennon York] removed the actual dev/run-tests command to further test jenkins d579662 [Brennon York] Merge remote-tracking branch 'upstream/master' into SPARK-5313 aa48029 [Brennon York] removed echo lines for testing jenkins 24cd965 [Brennon York] added test output to check within jenkins to verify 3a38e73 [Brennon York] removed the temporary read 9c881ff [Brennon York] updated test suite 183b7ee [Brennon York] added documentation on how to create tests 0bc2efe [Brennon York] ensure each test starts on the current pr branch 1743378 [Brennon York] added tests in test suite abd7430 [Brennon York] updated to include test suite --- dev/run-tests-jenkins | 75 ++++++++++++++-------------------- dev/tests/pr_merge_ability.sh | 39 ++++++++++++++++++ dev/tests/pr_public_classes.sh | 65 +++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 44 deletions(-) create mode 100755 dev/tests/pr_merge_ability.sh create mode 100755 dev/tests/pr_public_classes.sh diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 6a849e4f77207..5f4000e83925c 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -49,6 +49,21 @@ SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout +# Array to capture all tests to run on the pull request. These tests are held under the +#+ dev/tests/ directory. +# +# To write a PR test: +#+ * the file must reside within the dev/tests directory +#+ * be an executable bash script +#+ * accept two arguments on the command line, the first being the Github PR long commit +#+ hash and the second the Github SHA1 hash +#+ * and, lastly, return string output to be included in the pr message output that will +#+ be posted to Github +PR_TESTS=( + "pr_merge_ability" + "pr_public_classes" +) + function post_message () { local message=$1 local data="{\"body\": \"$message\"}" @@ -131,48 +146,22 @@ function send_archived_logs () { fi } - -# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR -#+ and not anything else added to master since the PR was branched. - -# check PR merge-ability and check for new public classes -{ - if [ "$sha1" == "$ghprbActualCommit" ]; then - merge_note=" * This patch **does not merge cleanly**." - else - merge_note=" * This patch merges cleanly." +# Environment variable to capture PR test output +pr_message="" + +# Run pull request tests +for t in "${PR_TESTS[@]}"; do + this_test="${FWDIR}/dev/tests/${t}.sh" + # Ensure the test is a file and is executable + if [ -x "$this_test" ]; then + echo "ghprb: $ghprbActualCommit sha1: $sha1" + this_mssg="`bash \"${this_test}\" \"${ghprbActualCommit}\" \"${sha1}\" 2>/dev/null`" + # Check if this is the merge test as we submit that note *before* and *after* + # the tests run + [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}" + pr_message="${pr_message}\n${this_mssg}" fi - - source_files=$( - git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ - | grep -v -e "\/test" `# ignore files in test directories` \ - | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ - | tr "\n" " " - ) - new_public_classes=$( - git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ - | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ - | grep -e "trait " -e "class " `# filter in lines with these key words` \ - | grep -e "{" -e "(" `# filter in lines with these key words, too` \ - | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ - | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ - | tr -d "\n" `# remove actual LF characters` - ) - - if [ -z "$new_public_classes" ]; then - public_classes_note=" * This patch adds no public classes." - else - public_classes_note=" * This patch adds the following public classes _(experimental)_:" - public_classes_note="${public_classes_note}\n${new_public_classes}" - fi -} +done # post start message { @@ -181,7 +170,6 @@ function send_archived_logs () { PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." start_message="${start_message}\n${merge_note}" - # start_message="${start_message}\n${public_classes_note}" post_message "$start_message" } @@ -234,8 +222,7 @@ function send_archived_logs () { PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" - result_message="${result_message}\n${merge_note}" - result_message="${result_message}\n${public_classes_note}" + result_message="${result_message}\n${pr_message}" post_message "$result_message" } diff --git a/dev/tests/pr_merge_ability.sh b/dev/tests/pr_merge_ability.sh new file mode 100755 index 0000000000000..d9a347fe24a8c --- /dev/null +++ b/dev/tests/pr_merge_ability.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. +# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# + +ghprbActualCommit="$1" +sha1="$2" + +# check PR merge-ability +if [ "${sha1}" == "${ghprbActualCommit}" ]; then + echo " * This patch **does not merge cleanly**." +else + echo " * This patch merges cleanly." +fi diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh new file mode 100755 index 0000000000000..927295b88c963 --- /dev/null +++ b/dev/tests/pr_public_classes.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. +# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# + +# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR +#+ and not anything else added to master since the PR was branched. + +ghprbActualCommit="$1" +sha1="$2" + +source_files=$( + git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + | grep -v -e "\/test" `# ignore files in test directories` \ + | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ + | tr "\n" " " +) +new_public_classes=$( + git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` +) + +if [ -z "$new_public_classes" ]; then + echo " * This patch adds no public classes." +else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + echo "${public_classes_note}\n${new_public_classes}" +fi From 3b5aaa6a5fe0d838a8570c9d500ebca5f63764f8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Mar 2015 15:25:32 -0400 Subject: [PATCH 05/47] [Core][minor] remove unused `visitedStages` in `DAGScheduler.stageDependsOn` We define and update `visitedStages` in `DAGScheduler.stageDependsOn`, but never read it. So we can safely remove it. Author: Wenchen Fan Closes #5086 from cloud-fan/minor and squashes the following commits: 24663ea [Wenchen Fan] remove un-used variable --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 1021172e6afb4..8feac6cb6b7a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1262,7 +1262,6 @@ class DAGScheduler( return true } val visitedRdds = new HashSet[RDD[_]] - val visitedStages = new HashSet[Stage] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting val waitingForVisit = new Stack[RDD[_]] @@ -1274,7 +1273,6 @@ class DAGScheduler( case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { - visitedStages += mapStage waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => From f17d43b033d928dbc46aef8e367aa08902e698ad Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Thu, 19 Mar 2015 12:46:10 -0700 Subject: [PATCH 06/47] [SPARK-6219] [Build] Check that Python code compiles This PR expands the Python lint checks so that they check for obvious compilation errors in our Python code. For example: ``` $ ./dev/lint-python Python lint checks failed. Compiling ./ec2/spark_ec2.py ... File "./ec2/spark_ec2.py", line 618 return (master_nodes,, slave_nodes) ^ SyntaxError: invalid syntax ./ec2/spark_ec2.py:618:25: E231 missing whitespace after ',' ./ec2/spark_ec2.py:1117:101: E501 line too long (102 > 100 characters) ``` This PR also bumps up the version of `pep8`. It ignores new types of checks introduced by that version bump while fixing problems missed by the older version of `pep8` we were using. Author: Nicholas Chammas Closes #4941 from nchammas/compile-spark-ec2 and squashes the following commits: 75e31d8 [Nicholas Chammas] upgrade pep8 + check compile b33651c [Nicholas Chammas] PEP8 line length --- dev/lint-python | 44 +++++++++++++++++++++++++++----------------- ec2/spark_ec2.py | 4 ++-- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index 772f856154ae0..fded654893a7c 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,43 +19,53 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" +PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" cd "$SPARK_ROOT_DIR" +# compileall: https://docs.python.org/2/library/compileall.html +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH" +compile_status="${PIPESTATUS[0]}" + # Get pep8 at runtime so that we don't rely on it being installed on the build server. #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 #+ TODOs: -#+ - Dynamically determine latest release version of pep8 and use that. -#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?)) +#+ - Download pep8 from PyPI. It's more "official". PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py" -PEP8_PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" +PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.6.2/pep8.py" +# if [ ! -e "$PEP8_SCRIPT_PATH" ]; then curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" -curl_status=$? +curl_status="$?" -if [ $curl_status -ne 0 ]; then +if [ "$curl_status" -ne 0 ]; then echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." - exit $curl_status + exit "$curl_status" fi - +# fi # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" $PEP8_PATHS_TO_CHECK > "$PEP8_REPORT_PATH" -pep8_status=${PIPESTATUS[0]} #$? +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +pep8_status="${PIPESTATUS[0]}" + +if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then + lint_status=0 +else + lint_status=1 +fi -if [ $pep8_status -ne 0 ]; then - echo "PEP 8 checks failed." - cat "$PEP8_REPORT_PATH" +if [ "$lint_status" -ne 0 ]; then + echo "Python lint checks failed." + cat "$PYTHON_LINT_REPORT_PATH" else - echo "PEP 8 checks passed." + echo "Python lint checks passed." fi -rm "$PEP8_REPORT_PATH" rm "$PEP8_SCRIPT_PATH" +rm "$PYTHON_LINT_REPORT_PATH" -exit $pep8_status +exit "$lint_status" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index f848874b0c775..c467cd08ed742 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -1159,8 +1159,8 @@ def real_main(): if EC2_INSTANCE_TYPES[opts.instance_type] != \ EC2_INSTANCE_TYPES[opts.master_instance_type]: print >> stderr, \ - "Error: spark-ec2 currently does not support having a master and slaves with " + \ - "different AMI virtualization types." + "Error: spark-ec2 currently does not support having a master and slaves " + \ + "with different AMI virtualization types." print >> stderr, "master instance virtualization type: {t}".format( t=EC2_INSTANCE_TYPES[opts.master_instance_type]) print >> stderr, "slave instance virtualization type: {t}".format( From 0745a305fac622a6eeb8aa4a7401205a14252939 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Mar 2015 22:12:01 -0400 Subject: [PATCH 07/47] Tighten up field/method visibility in Executor and made some code more clear to read. I was reading Executor just now and found that some latest changes introduced some weird code path with too much monadic chaining and unnecessary fields. I cleaned it up a bit, and also tightened up the visibility of various fields/methods. Also added some inline documentation to help understand this code better. Author: Reynold Xin Closes #4850 from rxin/executor and squashes the following commits: 866fc60 [Reynold Xin] Code review feedback. 020efbb [Reynold Xin] Tighten up field/method visibility in Executor and made some code more clear to read. --- .../org/apache/spark/TaskEndReason.scala | 6 +- .../executor/CommitDeniedException.scala | 6 +- .../org/apache/spark/executor/Executor.scala | 196 ++++++++++-------- .../spark/executor/ExecutorSource.scala | 16 +- .../org/apache/spark/scheduler/Task.scala | 2 +- 5 files changed, 120 insertions(+), 106 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 29a5cd5fdac76..48fd3e7e23d52 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -151,11 +151,7 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. */ @DeveloperApi -case class TaskCommitDenied( - jobID: Int, - partitionID: Int, - attemptID: Int) - extends TaskFailedReason { +case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + s" for job: $jobID, partition: $partitionID, attempt: $attemptID" } diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f7604a321f007..f47d7ef511da1 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -22,14 +22,12 @@ import org.apache.spark.{TaskCommitDenied, TaskEndReason} /** * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver. */ -class CommitDeniedException( +private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, attemptID: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID) - + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) } - diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 6196f7b165049..bf3135ef081c1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,7 +21,7 @@ import java.io.File import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent._ +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -31,15 +31,17 @@ import akka.actor.Props import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, - SparkUncaughtExceptionHandler, AkkaUtils, Utils} +import org.apache.spark.util._ /** - * Spark executor used with Mesos, YARN, and the standalone scheduler. - * In coarse-grained mode, an existing actor system is provided. + * Spark executor, backed by a threadpool to run tasks. + * + * This can be used with Mesos, YARN, and the standalone scheduler. + * An internal RPC interface (at the moment Akka) is used for communication with the driver, + * except in the case of Mesos fine-grained mode. */ private[spark] class Executor( executorId: String, @@ -47,8 +49,8 @@ private[spark] class Executor( env: SparkEnv, userClassPath: Seq[URL] = Nil, isLocal: Boolean = false) - extends Logging -{ + extends Logging { + logInfo(s"Starting executor ID $executorId on host $executorHostname") // Application dependencies (added through SparkContext) that we've fetched so far on this node. @@ -78,9 +80,8 @@ private[spark] class Executor( } // Start worker thread pool - val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") - - val executorSource = new ExecutorSource(this, executorId) + private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + private val executorSource = new ExecutorSource(threadPool, executorId) if (!isLocal) { env.metricsSystem.registerSource(executorSource) @@ -122,21 +123,21 @@ private[spark] class Executor( taskId: Long, attemptNumber: Int, taskName: String, - serializedTask: ByteBuffer) { + serializedTask: ByteBuffer): Unit = { val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, serializedTask) runningTasks.put(taskId, tr) threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean) { + def killTask(taskId: Long, interruptThread: Boolean): Unit = { val tr = runningTasks.get(taskId) if (tr != null) { tr.kill(interruptThread) } } - def stop() { + def stop(): Unit = { env.metricsSystem.report() env.actorSystem.stop(executorActor) isStopped = true @@ -146,7 +147,10 @@ private[spark] class Executor( } } - private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + /** Returns the total amount of time this JVM process has spent in garbage collection. */ + private def computeTotalGcTime(): Long = { + ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + } class TaskRunner( execBackend: ExecutorBackend, @@ -156,12 +160,19 @@ private[spark] class Executor( serializedTask: ByteBuffer) extends Runnable { + /** Whether this task has been killed. */ @volatile private var killed = false - @volatile var task: Task[Any] = _ - @volatile var attemptedTask: Option[Task[Any]] = None + + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ - def kill(interruptThread: Boolean) { + /** + * The task to run. This will be set in run() by deserializing the task binary coming + * from the driver. Once it is set, it will never be changed. + */ + @volatile var task: Task[Any] = _ + + def kill(interruptThread: Boolean): Unit = { logInfo(s"Executor is trying to kill $taskName (TID $taskId)") killed = true if (task != null) { @@ -169,14 +180,14 @@ private[spark] class Executor( } } - override def run() { + override def run(): Unit = { val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var taskStart: Long = 0 - startGCTime = gcTime + startGCTime = computeTotalGcTime() try { val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) @@ -193,7 +204,6 @@ private[spark] class Executor( throw new TaskKilledException } - attemptedTask = Some(task) logDebug("Task " + taskId + "'s epoch is " + task.epoch) env.mapOutputTracker.updateEpoch(task.epoch) @@ -215,18 +225,17 @@ private[spark] class Executor( for (m <- task.metrics) { m.setExecutorDeserializeTime(taskStart - deserializeStartTime) m.setExecutorRunTime(taskFinish - taskStart) - m.setJvmGCTime(gcTime - startGCTime) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) } val accumUpdates = Accumulators.values - val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val serializedResult = { + val serializedResult: ByteBuffer = { if (maxResultSize > 0 && resultSize > maxResultSize) { logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + @@ -248,42 +257,40 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => { + case ffe: FetchFailedException => val reason = ffe.toTaskEndReason execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - } - case _: TaskKilledException | _: InterruptedException if task.killed => { + case _: TaskKilledException | _: InterruptedException if task.killed => logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - } - case cDE: CommitDeniedException => { + case cDE: CommitDeniedException => val reason = cDE.toTaskEndReason execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - } - case t: Throwable => { + case t: Throwable => // Attempt to exit cleanly by informing the driver of our failure. // If anything goes wrong (or this was a fatal exception), we will delegate to // the default uncaught exception handler, which will terminate the Executor. logError(s"Exception in $taskName (TID $taskId)", t) - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.setExecutorRunTime(serviceTime) - m.setJvmGCTime(gcTime - startGCTime) + val metrics: Option[TaskMetrics] = Option(task).flatMap { task => + task.metrics.map { m => + m.setExecutorRunTime(System.currentTimeMillis() - taskStart) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) + m + } } - val reason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + val taskEndReason = new ExceptionFailure(t, metrics) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { SparkUncaughtExceptionHandler.uncaughtException(t) } - } + } finally { // Release memory used by this thread for shuffles env.shuffleMemoryManager.releaseMemoryForThisThread() @@ -358,7 +365,7 @@ private[spark] class Executor( for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentFiles(name) = timestamp } @@ -370,12 +377,12 @@ private[spark] class Executor( if (currentTimeStamp < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentJars(name) = timestamp // Add it to our class loader - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { + val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL + if (!urlClassLoader.getURLs().contains(url)) { logInfo("Adding " + url + " to class loader") urlClassLoader.addURL(url) } @@ -384,61 +391,70 @@ private[spark] class Executor( } } - def startDriverHeartbeater() { - val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val timeout = AkkaUtils.lookupTimeout(conf) - val retryAttempts = AkkaUtils.numRetries(conf) - val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + private val timeout = AkkaUtils.lookupTimeout(conf) + private val retryAttempts = AkkaUtils.numRetries(conf) + private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + private val heartbeatReceiverRef = + AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + + /** Reports heartbeat and metrics for active tasks to the driver. */ + private def reportHeartBeat(): Unit = { + // list of (task id, metrics) to send back to the driver + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + val curGCTime = computeTotalGcTime() + + for (taskRunner <- runningTasks.values()) { + if (taskRunner.task != null) { + taskRunner.task.metrics.foreach { metrics => + metrics.updateShuffleReadMetrics() + metrics.updateInputMetrics() + metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + + if (isLocal) { + // JobProgressListener will hold an reference of it during + // onExecutorMetricsUpdate(), then JobProgressListener can not see + // the changes of metrics any more, so make a deep copy of it + val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) + tasksMetrics += ((taskRunner.taskId, copiedMetrics)) + } else { + // It will be copied by serialization + tasksMetrics += ((taskRunner.taskId, metrics)) + } + } + } + } - val t = new Thread() { + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + try { + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + } catch { + case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + } + } + + /** + * Starts a thread to report heartbeat and partial metrics for active tasks to driver. + * This thread stops running when the executor is stopped. + */ + private def startDriverHeartbeater(): Unit = { + val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val thread = new Thread() { override def run() { // Sleep a random interval so the heartbeats don't end up in sync Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) - while (!isStopped) { - val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() - val curGCTime = gcTime - - for (taskRunner <- runningTasks.values()) { - if (taskRunner.attemptedTask.nonEmpty) { - Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => - metrics.updateShuffleReadMetrics() - metrics.updateInputMetrics() - metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - - if (isLocal) { - // JobProgressListener will hold an reference of it during - // onExecutorMetricsUpdate(), then JobProgressListener can not see - // the changes of metrics any more, so make a deep copy of it - val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) - tasksMetrics += ((taskRunner.taskId, copiedMetrics)) - } else { - // It will be copied by serialization - tasksMetrics += ((taskRunner.taskId, metrics)) - } - } - } - } - - val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) - try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) - if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") - env.blockManager.reregister() - } - } catch { - case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t) - } - + reportHeartBeat() Thread.sleep(interval) } } } - t.setDaemon(true) - t.setName("Driver Heartbeater") - t.start() + thread.setDaemon(true) + thread.setName("driver-heartbeater") + thread.start() } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index c4d73622c4727..293c512f8b70c 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -17,6 +17,8 @@ package org.apache.spark.executor +import java.util.concurrent.ThreadPoolExecutor + import scala.collection.JavaConversions._ import com.codahale.metrics.{Gauge, MetricRegistry} @@ -24,9 +26,11 @@ import org.apache.hadoop.fs.FileSystem import org.apache.spark.metrics.source.Source -private[spark] class ExecutorSource(val executor: Executor, executorId: String) extends Source { +private[spark] +class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source { + private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption + FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme)) private def registerFileSystemStat[T]( scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { @@ -41,23 +45,23 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getActiveCount() + override def getValue: Int = threadPool.getActiveCount() }) // Gauge for executor thread pool's approximate total number of tasks that have been completed metricRegistry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] { - override def getValue: Long = executor.threadPool.getCompletedTaskCount() + override def getValue: Long = threadPool.getCompletedTaskCount() }) // Gauge for executor thread pool's current number of threads metricRegistry.register(MetricRegistry.name("threadpool", "currentPool_size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getPoolSize() + override def getValue: Int = threadPool.getPoolSize() }) // Gauge got executor thread pool's largest number of threads that have ever simultaneously // been in th pool metricRegistry.register(MetricRegistry.name("threadpool", "maxPool_size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getMaximumPoolSize() + override def getValue: Int = threadPool.getMaximumPoolSize() }) // Gauge for file system stats of this executor diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 847a4912eec13..4d9f940813b8e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { /** - * Called by Executor to run this task. + * Called by [[Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) From 116c553fd6f6d2adcbbf000cd80b5c46d4516e87 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 20 Mar 2015 12:24:34 +0000 Subject: [PATCH 08/47] [SPARK-6286][Mesos][minor] Handle missing Mesos case TASK_ERROR - Made TaskState.isFailed for handling TASK_LOST and TASK_ERROR and synchronizing CoarseMesosSchedulerBackend and MesosSchedulerBackend - This is related #5000 Author: Jongyoul Lee Closes #5088 from jongyoul/SPARK-6286-1 and squashes the following commits: 4f2362f [Jongyoul Lee] [SPARK-6286][Mesos][minor] Handle missing Mesos case TASK_ERROR - Fixed scalastyle ac4336a [Jongyoul Lee] [SPARK-6286][Mesos][minor] Handle missing Mesos case TASK_ERROR - Made TaskState.isFailed for handling TASK_LOST and TASK_ERROR and synchronizing CoarseMesosSchedulerBackend and MesosSchedulerBackend --- core/src/main/scala/org/apache/spark/TaskState.scala | 2 ++ .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index d85a6d683427d..c415fe99b105e 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -27,6 +27,8 @@ private[spark] object TaskState extends Enumeration { type TaskState = Value + def isFailed(state: TaskState) = (LOST == state) || (FAILED == state) + def isFinished(state: TaskState) = FINISHED_STATES.contains(state) def toMesos(state: TaskState): MesosTaskState = state match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index fc92b9c35c3a3..e13de0f46ef89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -277,7 +277,7 @@ private[spark] class CoarseMesosSchedulerBackend( coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) { + if (TaskState.isFailed(TaskState.fromMesos(state))) { failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index df8f4306b88a8..06bb527522141 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -318,7 +318,8 @@ private[spark] class MesosSchedulerBackend( val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { - if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { + if (TaskState.isFailed(TaskState.fromMesos(status.getState)) + && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone removeExecutor(taskIdToSlaveId(tid), "Lost executor") } From d08e3eb3dc455970b685a7b8b7e00c537c89a8e4 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 20 Mar 2015 14:14:53 +0000 Subject: [PATCH 09/47] SPARK-5134 [BUILD] Bump default Hadoop version to 2+ Bump default Hadoop version to 2.2.0. (This is already the dependency version reported by published Maven artifacts.) See JIRA for further discussion. Author: Sean Owen Closes #5027 from srowen/SPARK-5134 and squashes the following commits: acbee14 [Sean Owen] Bump default Hadoop version to 2.2.0. (This is already the dependency version reported by published Maven artifacts.) --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 6fc56a86d44ac..efb9f172f4751 100644 --- a/pom.xml +++ b/pom.xml @@ -120,7 +120,7 @@ shaded-protobuf 1.7.10 1.2.17 - 1.0.4 + 2.2.0 2.4.1 ${hadoop.version} 0.98.7-hadoop1 From 6f80c3e8880340597f161f87e64697bec86cc586 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 20 Mar 2015 14:16:21 +0000 Subject: [PATCH 10/47] SPARK-6338 [CORE] Use standard temp dir mechanisms in tests to avoid orphaned temp files Use `Utils.createTempDir()` to replace other temp file mechanisms used in some tests, to further ensure they are cleaned up, and simplify Author: Sean Owen Closes #5029 from srowen/SPARK-6338 and squashes the following commits: 27b740a [Sean Owen] Fix hive-thriftserver tests that don't expect an existing dir 4a212fa [Sean Owen] Standardize a bit more temp dir management 9004081 [Sean Owen] Revert some added recursive-delete calls 57609e4 [Sean Owen] Use Utils.createTempDir() to replace other temp file mechanisms used in some tests, to further ensure they are cleaned up, and simplify --- .../spark/deploy/FaultToleranceTest.scala | 4 ++-- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../org/apache/spark/CheckpointSuite.scala | 3 +-- .../apache/spark/SecurityManagerSuite.scala | 5 +++-- .../org/apache/spark/SparkContextSuite.scala | 11 +++++----- .../spark/deploy/SparkSubmitSuite.scala | 8 ++++--- .../org/apache/spark/rdd/PipedRDDSuite.scala | 6 +++-- .../storage/BlockObjectWriterSuite.scala | 16 +++++++------- .../apache/spark/util/FileAppenderSuite.scala | 2 +- .../org/apache/spark/util/UtilsSuite.scala | 7 +++--- .../kafka/ReliableKafkaStreamSuite.scala | 10 +++------ .../org/apache/spark/graphx/GraphSuite.scala | 6 ++--- .../org/apache/spark/repl/ReplSuite.scala | 5 +---- .../expressions/codegen/package.scala | 4 ++-- .../spark/sql/catalyst/util/package.scala | 15 ++----------- .../spark/sql/parquet/ParquetTest.scala | 6 ++--- .../spark/sql/UserDefinedTypeSuite.scala | 6 +++-- .../org/apache/spark/sql/json/JsonSuite.scala | 22 +++++++++++-------- .../sources/CreateTableAsSelectSuite.scala | 5 ++--- .../spark/sql/sources/InsertSuite.scala | 5 ++--- .../spark/sql/sources/SaveLoadSuite.scala | 6 ++--- .../sql/hive/thriftserver/CliSuite.scala | 8 ++++--- .../HiveThriftServer2Suites.scala | 7 +++--- .../apache/spark/sql/hive/test/TestHive.scala | 16 +++++--------- .../sql/hive/InsertIntoHiveTableSuite.scala | 5 ++--- .../sql/hive/MetastoreDataSourcesSuite.scala | 14 ++++++------ .../apache/spark/sql/hive/parquetSuites.scala | 22 +++++-------------- .../spark/streaming/CheckpointSuite.scala | 6 ++--- .../apache/spark/streaming/FailureSuite.scala | 10 ++++----- .../streaming/ReceivedBlockHandlerSuite.scala | 11 +++------- .../streaming/ReceivedBlockTrackerSuite.scala | 9 ++------ .../spark/streaming/ReceiverSuite.scala | 5 ++--- .../yarn/YarnSparkHadoopUtilSuite.scala | 2 +- 33 files changed, 116 insertions(+), 153 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 4e58aa0ed4c7e..5668b53fc6f4f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -33,6 +33,7 @@ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.util.Utils /** * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master. @@ -405,8 +406,7 @@ private object SparkDocker { private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = { val ipPromise = promise[String]() - val outFile = File.createTempFile("fault-tolerance-test", "") - outFile.deleteOnExit() + val outFile = File.createTempFile("fault-tolerance-test", "", Utils.createTempDir()) val outStream: FileWriter = new FileWriter(outFile) def findIpAndLog(line: String): Unit = { if (line.startsWith("CONTAINER_IP=")) { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 91aa70870ab20..fa56bb09e4e5c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -288,7 +288,7 @@ private[spark] object Utils extends Logging { } catch { case e: SecurityException => dir = null; } } - dir + dir.getCanonicalFile } /** diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 3b10b3a042317..32abc65385267 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -33,8 +33,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { override def beforeEach() { super.beforeEach() - checkpointDir = File.createTempFile("temp", "") - checkpointDir.deleteOnExit() + checkpointDir = File.createTempFile("temp", "", Utils.createTempDir()) checkpointDir.delete() sc = new SparkContext("local", "test") sc.setCheckpointDir(checkpointDir.toString) diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 43fbd3ff3f756..62cb7649c0284 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.scalatest.FunSuite +import org.apache.spark.util.Utils + class SecurityManagerSuite extends FunSuite { test("set security with conf") { @@ -160,8 +162,7 @@ class SecurityManagerSuite extends FunSuite { } test("ssl off setup") { - val file = File.createTempFile("SSLOptionsSuite", "conf") - file.deleteOnExit() + val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir()) System.setProperty("spark.ssl.configFile", file.getAbsolutePath) val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b8e3e83b5a47b..b07c4d93db4e6 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -79,13 +79,14 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val byteArray2 = converter.convert(bytesWritable) assert(byteArray2.length === 0) } - + test("addFile works") { - val file1 = File.createTempFile("someprefix1", "somesuffix1") + val dir = Utils.createTempDir() + + val file1 = File.createTempFile("someprefix1", "somesuffix1", dir) val absolutePath1 = file1.getAbsolutePath - val pluto = Utils.createTempDir() - val file2 = File.createTempFile("someprefix2", "somesuffix2", pluto) + val file2 = File.createTempFile("someprefix2", "somesuffix2", dir) val relativePath = file2.getParent + "/../" + file2.getParentFile.getName + "/" + file2.getName val absolutePath2 = file2.getAbsolutePath @@ -129,7 +130,7 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } - + test("addFile recursive works") { val pluto = Utils.createTempDir() val neptune = Utils.createTempDir(pluto.getAbsolutePath) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 46d745c4ecbfa..4561e5b8e9663 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -402,8 +402,10 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val archives = "file:/archive1,archive2" // spark.yarn.dist.archives val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles + val tmpDir = Utils.createTempDir() + // Test jars and files - val f1 = File.createTempFile("test-submit-jars-files", "") + val f1 = File.createTempFile("test-submit-jars-files", "", tmpDir) val writer1 = new PrintWriter(f1) writer1.println("spark.jars " + jars) writer1.println("spark.files " + files) @@ -420,7 +422,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties sysProps("spark.files") should be(Utils.resolveURIs(files)) // Test files and archives (Yarn) - val f2 = File.createTempFile("test-submit-files-archives", "") + val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir) val writer2 = new PrintWriter(f2) writer2.println("spark.yarn.dist.files " + files) writer2.println("spark.yarn.dist.archives " + archives) @@ -437,7 +439,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) // Test python files - val f3 = File.createTempFile("test-submit-python-files", "") + val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) val writer3 = new PrintWriter(f3) writer3.println("spark.submit.pyFiles " + pyFiles) writer3.close() diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 1a9a0e857e546..aea76c1adcc09 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} -import org.apache.spark._ import org.scalatest.FunSuite import scala.collection.Map @@ -30,6 +29,9 @@ import scala.language.postfixOps import scala.sys.process._ import scala.util.Try +import org.apache.spark._ +import org.apache.spark.util.Utils + class PipedRDDSuite extends FunSuite with SharedSparkContext { test("basic pipe") { @@ -141,7 +143,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { // make sure symlinks were created assert(pipedLs.length > 0) // clean up top level tasks directory - new File("tasks").delete() + Utils.deleteRecursively(new File("tasks")) } else { assert(true) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index c21c92b63ad13..78bbc4ec2c620 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -16,16 +16,18 @@ */ package org.apache.spark.storage -import org.scalatest.FunSuite import java.io.File + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.SparkConf +import org.apache.spark.util.Utils class BlockObjectWriterSuite extends FunSuite { test("verify write metrics") { - val file = new File("somefile") - file.deleteOnExit() + val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) @@ -47,8 +49,7 @@ class BlockObjectWriterSuite extends FunSuite { } test("verify write metrics on revert") { - val file = new File("somefile") - file.deleteOnExit() + val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) @@ -71,8 +72,7 @@ class BlockObjectWriterSuite extends FunSuite { } test("Reopening a closed block writer") { - val file = new File("somefile") - file.deleteOnExit() + val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 4dc5b6103db74..43b6a405cb68c 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolic class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { - val testFile = new File("FileAppenderSuite-test-" + System.currentTimeMillis).getAbsoluteFile + val testFile = new File(Utils.createTempDir(), "FileAppenderSuite-test").getAbsoluteFile before { cleanup() diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index b91428efadfd0..5d93086082189 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -122,7 +122,6 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("reading offset bytes of a file") { val tmpDir2 = Utils.createTempDir() - tmpDir2.deleteOnExit() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(UTF_8)) @@ -151,7 +150,6 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("reading offset bytes across multiple files") { val tmpDir = Utils.createTempDir() - tmpDir.deleteOnExit() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) Files.write("0123456789", files(0), UTF_8) Files.write("abcdefghij", files(1), UTF_8) @@ -357,7 +355,8 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } test("loading properties from file") { - val outFile = File.createTempFile("test-load-spark-properties", "test") + val tmpDir = Utils.createTempDir() + val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) try { System.setProperty("spark.test.fileNameLoadB", "2") Files.write("spark.test.fileNameLoadA true\n" + @@ -370,7 +369,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true) assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2) } finally { - outFile.delete() + Utils.deleteRecursively(tmpDir) } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index fc53c23abda85..3cd960d1fd1d4 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -25,16 +25,15 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import com.google.common.io.Files import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.util.Utils class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { @@ -60,7 +59,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter ) ssc = new StreamingContext(sparkConf, Milliseconds(500)) - tempDirectory = Files.createTempDir() + tempDirectory = Utils.createTempDir() ssc.checkpoint(tempDirectory.getAbsolutePath) } @@ -68,10 +67,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter if (ssc != null) { ssc.stop() } - if (tempDirectory != null && tempDirectory.exists()) { - FileUtils.deleteDirectory(tempDirectory) - tempDirectory = null - } + Utils.deleteRecursively(tempDirectory) tearDownKafka() } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index b61d9f0fbe5e4..8d15150458d26 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.graphx import org.scalatest.FunSuite -import com.google.common.io.Files - import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils class GraphSuite extends FunSuite with LocalSparkContext { @@ -369,8 +368,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { } test("checkpoint") { - val checkpointDir = Files.createTempDir() - checkpointDir.deleteOnExit() + val checkpointDir = Utils.createTempDir() withSpark { sc => sc.setCheckpointDir(checkpointDir.getAbsolutePath) val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1)} diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index fbef5b25ba688..14f5e9ed4f25e 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,11 +21,9 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.tools.nsc.interpreter.SparkILoop -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkContext @@ -196,8 +194,7 @@ class ReplSuite extends FunSuite { } test("interacting with files") { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val out = new FileWriter(tempDir + "/input") out.write("Hello world!\n") out.write("What's up?\n") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 80c7dfd376c96..528e38a50a740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.util +import org.apache.spark.util.Utils /** * A collection of generators that build custom bytecode at runtime for performing the evaluation @@ -52,7 +52,7 @@ package object codegen { @DeveloperApi object DumpByteCode { import scala.sys.process._ - val dumpDirectory = util.getTempFilePath("sparkSqlByteCode") + val dumpDirectory = Utils.createTempDir() dumpDirectory.mkdir() def apply(obj: Any): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index d8da45ae70c4b..feed50f9a2a2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,20 +19,9 @@ package org.apache.spark.sql.catalyst import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File} -import org.apache.spark.util.{Utils => SparkUtils} +import org.apache.spark.util.Utils package object util { - /** - * Returns a path to a temporary file that probably does not exist. - * Note, there is always the race condition that someone created this - * file since the last time we checked. Thus, this shouldn't be used - * for anything security conscious. - */ - def getTempFilePath(prefix: String, suffix: String = ""): File = { - val tempFile = File.createTempFile(prefix, suffix) - tempFile.delete() - tempFile - } def fileToString(file: File, encoding: String = "UTF-8") = { val inStream = new FileInputStream(file) @@ -56,7 +45,7 @@ package object util { def resourceToString( resource:String, encoding: String = "UTF-8", - classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = { + classLoader: ClassLoader = Utils.getSparkClassLoader) = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index d6ea6679c5966..9d17516e0ef7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,7 +23,6 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.catalyst.util import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils @@ -67,8 +66,9 @@ private[sql] trait ParquetTest { * @todo Probably this method should be moved to a more general place */ protected def withTempPath(f: File => Unit): Unit = { - val file = util.getTempFilePath("parquetTest").getCanonicalFile - try f(file) finally if (file.exists()) Utils.deleteRecursively(file) + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 23f424c0bfc7c..fe618e0e8e767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.io.File +import org.apache.spark.util.Utils + import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD @@ -98,13 +100,13 @@ class UserDefinedTypeSuite extends QueryTest { test("UDTs with Parquet") { - val tempDir = File.createTempFile("parquet", "test") + val tempDir = Utils.createTempDir() tempDir.delete() pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) } test("Repartition UDTs with Parquet") { - val tempDir = File.createTempFile("parquet", "test") + val tempDir = Utils.createTempDir() tempDir.delete() pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 320b80d80e997..706c966ee05f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -22,7 +22,6 @@ import java.sql.{Date, Timestamp} import org.scalactic.Tolerance._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.sources.LogicalRelation @@ -31,6 +30,7 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.util.Utils class JsonSuite extends QueryTest { import org.apache.spark.sql.json.TestJsonData._ @@ -554,8 +554,9 @@ class JsonSuite extends QueryTest { } test("jsonFile should be based on JSONRelation") { - val file = getTempFilePath("json") - val path = file.toString + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val jsonDF = jsonFile(path, 0.49) @@ -580,8 +581,9 @@ class JsonSuite extends QueryTest { } test("Loading a JSON dataset from a text file") { - val file = getTempFilePath("json") - val path = file.toString + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val jsonDF = jsonFile(path) @@ -611,8 +613,9 @@ class JsonSuite extends QueryTest { } test("Loading a JSON dataset from a text file with SQL") { - val file = getTempFilePath("json") - val path = file.toString + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) sql( @@ -637,8 +640,9 @@ class JsonSuite extends QueryTest { } test("Applying schemas") { - val file = getTempFilePath("json") - val path = file.toString + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 60355414a40fa..2975a7fee4c96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.spark.sql.AnalysisException import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.util import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { @@ -32,7 +31,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { var path: File = null override def beforeAll(): Unit = { - path = util.getTempFilePath("jsonCTAS").getCanonicalFile + path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) jsonRDD(rdd).registerTempTable("jt") } @@ -42,7 +41,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { } after { - if (path.exists()) Utils.deleteRecursively(path) + Utils.deleteRecursively(path) } test("CREATE TEMPORARY TABLE AS SELECT") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index b5b16f9546691..80efe9728fbc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.catalyst.util import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with BeforeAndAfterAll { @@ -32,7 +31,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { var path: File = null override def beforeAll: Unit = { - path = util.getTempFilePath("jsonCTAS").getCanonicalFile + path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) jsonRDD(rdd).registerTempTable("jt") sql( @@ -48,7 +47,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { override def afterAll: Unit = { dropTempTable("jsonTable") dropTempTable("jt") - if (path.exists()) Utils.deleteRecursively(path) + Utils.deleteRecursively(path) } test("Simple INSERT OVERWRITE a JSONRelation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 607488ccfdd6a..43bc8eb2d11a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,7 +21,6 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.util import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -39,7 +38,8 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { override def beforeAll(): Unit = { originalDefaultSource = conf.defaultDataSourceName - path = util.getTempFilePath("datasource").getCanonicalFile + path = Utils.createTempDir() + path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) df = jsonRDD(rdd) @@ -52,7 +52,7 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { after { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) - if (path.exists()) Utils.deleteRecursively(path) + Utils.deleteRecursively(path) } def checkLoad(): Unit = { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 8bca4b33b3ad1..75738fa22b572 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.util.Utils class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def runCliWithin( @@ -38,8 +38,10 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { queriesAndExpectedAnswers: (String, String)*) { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val warehousePath = getTempFilePath("warehouse") - val metastorePath = getTempFilePath("metastore") + val warehousePath = Utils.createTempDir() + warehousePath.delete() + val metastorePath = Utils.createTempDir() + metastorePath.delete() val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val command = { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index aff96e21a5373..bf20acecb1f32 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -37,7 +37,6 @@ import org.apache.thrift.transport.TSocket import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils @@ -447,8 +446,10 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit } private def startThriftServer(port: Int, attempt: Int) = { - warehousePath = util.getTempFilePath("warehouse") - metastorePath = util.getTempFilePath("metastore") + warehousePath = Utils.createTempDir() + warehousePath.delete() + metastorePath = Utils.createTempDir() + metastorePath.delete() logPath = null logTailingProcess = null diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4859991e2351a..b4aee78046383 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand @@ -69,22 +68,19 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveconf.set("hive.plan.serialization.format", "javaXML") - lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath - lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + lazy val warehousePath = Utils.createTempDir() + lazy val metastorePath = Utils.createTempDir() /** Sets up the system initially or after a RESET command */ protected def configure(): Unit = { + warehousePath.delete() + metastorePath.delete() setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") - setConf("hive.metastore.warehouse.dir", warehousePath) - Utils.registerShutdownDeleteDir(new File(warehousePath)) - Utils.registerShutdownDeleteDir(new File(metastorePath)) + setConf("hive.metastore.warehouse.dir", warehousePath.toString) } - val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") - testTempDir.delete() - testTempDir.mkdir() - Utils.registerShutdownDeleteDir(testTempDir) + val testTempDir = Utils.createTempDir() // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d4b175fa443a4..381cd2a29123e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -21,12 +21,11 @@ import java.io.File import org.scalatest.BeforeAndAfter -import com.google.common.io.Files - import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ @@ -112,7 +111,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 5d6a6f3b64f03..ff2e6ea9ea51d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.hive import java.io.File +import scala.collection.mutable.ArrayBuffer + import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException -import org.apache.spark.sql.catalyst.util import org.apache.spark.sql._ import org.apache.spark.util.Utils import org.apache.spark.sql.types._ @@ -34,8 +35,6 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation -import scala.collection.mutable.ArrayBuffer - /** * Tests for persisting tables created though the data sources API into the metastore. */ @@ -43,11 +42,12 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { override def afterEach(): Unit = { reset() - if (tempPath.exists()) Utils.deleteRecursively(tempPath) + Utils.deleteRecursively(tempPath) } val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var tempPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile + var tempPath: File = Utils.createTempDir() + tempPath.delete() test ("persistent JSON table") { sql( @@ -154,7 +154,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { } test("check change without refresh") { - val tempDir = File.createTempFile("sparksql", "json") + val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) tempDir.delete() sparkContext.parallelize(("a", "b") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -192,7 +192,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { } test("drop, change, recreate") { - val tempDir = File.createTempFile("sparksql", "json") + val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) tempDir.delete() sparkContext.parallelize(("a", "b") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 1904f5faef3a0..d891c4e8903d9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.SaveMode import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -579,13 +580,8 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll var partitionedTableDirWithKeyAndComplexTypes: File = null override def beforeAll(): Unit = { - partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - normalTableDir = File.createTempFile("parquettests", "sparksql") - normalTableDir.delete() - normalTableDir.mkdir() + partitionedTableDir = Utils.createTempDir() + normalTableDir = Utils.createTempDir() (1 to 10).foreach { p => val partDir = new File(partitionedTableDir, s"p=$p") @@ -601,9 +597,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll .toDF() .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) - partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") - partitionedTableDirWithKey.delete() - partitionedTableDirWithKey.mkdir() + partitionedTableDirWithKey = Utils.createTempDir() (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithKey, s"p=$p") @@ -613,9 +607,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll .saveAsParquetFile(partDir.getCanonicalPath) } - partitionedTableDirWithKeyAndComplexTypes = File.createTempFile("parquettests", "sparksql") - partitionedTableDirWithKeyAndComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.mkdir() + partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") @@ -625,9 +617,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll }.toDF().saveAsParquetFile(partDir.getCanonicalPath) } - partitionedTableDirWithComplexTypes = File.createTempFile("parquettests", "sparksql") - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithComplexTypes.mkdir() + partitionedTableDirWithComplexTypes = Utils.createTempDir() (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8ea91eca683cf..91a2b2bba461d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -222,7 +222,7 @@ class CheckpointSuite extends TestSuiteBase { } test("recovery with saveAsHadoopFiles operation") { - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -245,7 +245,7 @@ class CheckpointSuite extends TestSuiteBase { } test("recovery with saveAsNewAPIHadoopFiles operation") { - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -283,7 +283,7 @@ class CheckpointSuite extends TestSuiteBase { // // After SPARK-5079 is addressed, should be able to remove this test since a strengthened // version of the other saveAsHadoopFile* tests would prevent regressions for this issue. - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 6500608bba87c..26435d8515815 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.streaming import org.apache.spark.Logging import org.apache.spark.util.Utils -import java.io.File - /** * This testsuite tests master failures at random times while the stream is running using * the real clock. */ class FailureSuite extends TestSuiteBase with Logging { - val directory = Utils.createTempDir().getAbsolutePath + val directory = Utils.createTempDir() val numBatches = 30 override def batchDuration = Milliseconds(1000) @@ -36,16 +34,16 @@ class FailureSuite extends TestSuiteBase with Logging { override def useManualClock = false override def afterFunction() { - Utils.deleteRecursively(new File(directory)) + Utils.deleteRecursively(directory) super.afterFunction() } test("multiple failures with map") { - MasterFailureTest.testMap(directory, numBatches, batchDuration) + MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration) } test("multiple failures with updateStateByKey") { - MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) + MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 818f551dbe996..18a477f92094d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -25,8 +25,6 @@ import scala.concurrent.duration._ import scala.language.postfixOps import akka.actor.{ActorSystem, Props} -import com.google.common.io.Files -import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ @@ -39,7 +37,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.{AkkaUtils, ManualClock} +import org.apache.spark.util.{AkkaUtils, ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -76,7 +74,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") - tempDirectory = Files.createTempDir() + tempDirectory = Utils.createTempDir() manualClock.setTime(0) } @@ -93,10 +91,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche actorSystem.awaitTermination() actorSystem = null - if (tempDirectory != null && tempDirectory.exists()) { - FileUtils.deleteDirectory(tempDirectory) - tempDirectory = null - } + Utils.deleteRecursively(tempDirectory) } test("BlockManagerBasedBlockHandler - store blocks") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index a3a0fd5187403..42fad769f0c1a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -24,8 +24,6 @@ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} import scala.util.Random -import com.google.common.io.Files -import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ @@ -51,15 +49,12 @@ class ReceivedBlockTrackerSuite before { conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") - checkpointDirectory = Files.createTempDir() + checkpointDirectory = Utils.createTempDir() } after { allReceivedBlockTrackers.foreach { _.stop() } - if (checkpointDirectory != null && checkpointDirectory.exists()) { - FileUtils.deleteDirectory(checkpointDirectory) - checkpointDirectory = null - } + Utils.deleteRecursively(checkpointDirectory) } test("block addition, and block to batch allocation") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e8c34a9ee40b9..aa20ad0b5374e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -24,7 +24,6 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -34,6 +33,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ +import org.apache.spark.util.Utils /** Testsuite for testing the network receiver behavior */ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { @@ -222,7 +222,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { .set("spark.streaming.receiver.writeAheadLog.enable", "true") .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") val batchDuration = Milliseconds(500) - val tempDirectory = Files.createTempDir() + val tempDirectory = Utils.createTempDir() val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) val allLogFiles1 = new mutable.HashSet[String]() @@ -251,7 +251,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => - tempDirectory.deleteOnExit() val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) val receiverStream1 = ssc.receiverStream(receiver1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index b5a2db8f6225c..4194f36499e66 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -50,7 +50,7 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { if (hasBash) test(name)(fn) else ignore(name)(fn) bashTest("shell script escaping") { - val scriptFile = File.createTempFile("script.", ".sh") + val scriptFile = File.createTempFile("script.", ".sh", Utils.createTempDir()) val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") try { val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") From db4d317ccfdd9bd1dc7e8beac54ebcc35966b7d5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 20 Mar 2015 14:13:02 -0400 Subject: [PATCH 11/47] [SPARK-6428][MLlib] Added explicit type for public methods and implemented hashCode when equals is defined. I want to add a checker to turn public type checking on, since future pull requests can accidentally expose a non-public type. This is the first cleanup task. Author: Reynold Xin Closes #5102 from rxin/mllib-hashcode-publicmethodtypes and squashes the following commits: 617f19e [Reynold Xin] Fixed Scala compilation error. 52bc2d5 [Reynold Xin] [MLlib] Added explicit type for public methods and implemented hashCode when equals is defined. --- .../spark/examples/mllib/MovieLensALS.scala | 3 +- .../PowerIterationClusteringExample.scala | 4 +-- .../apache/spark/ml/feature/HashingTF.scala | 2 +- .../mllib/api/python/PythonMLLibAPI.scala | 18 ++++++---- .../mllib/classification/NaiveBayes.scala | 6 ++-- .../impl/GLMClassificationModel.scala | 2 +- .../spark/mllib/clustering/KMeans.scala | 2 +- .../mllib/evaluation/MultilabelMetrics.scala | 18 +++++----- .../apache/spark/mllib/linalg/Matrices.scala | 12 +++++-- .../apache/spark/mllib/linalg/Vectors.scala | 4 ++- .../linalg/distributed/BlockMatrix.scala | 10 +++++- .../mllib/random/RandomDataGenerator.scala | 4 +-- .../regression/impl/GLMRegressionModel.scala | 2 +- .../mllib/tree/configuration/Strategy.scala | 9 +++-- .../spark/mllib/tree/impurity/Entropy.scala | 2 +- .../spark/mllib/tree/impurity/Gini.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 2 +- .../mllib/tree/model/DecisionTreeModel.scala | 4 +-- .../tree/model/InformationGainStats.scala | 35 ++++++++++++------- .../apache/spark/mllib/tree/model/Node.scala | 6 ++-- .../spark/mllib/tree/model/Predict.scala | 6 +++- .../apache/spark/mllib/tree/model/Split.scala | 3 +- .../mllib/tree/model/treeEnsembleModels.scala | 6 ++-- 23 files changed, 101 insertions(+), 61 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 91a0a860d6c71..1f4ca4fbe7778 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -175,7 +175,8 @@ object MovieLensALS { } /** Compute RMSE (Root Mean Squared Error). */ - def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = { + def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) + : Double = { def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 91c9772744f18..9f22d40c15f3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -116,7 +116,7 @@ object PowerIterationClusteringExample { sc.stop() } - def generateCircle(radius: Double, n: Int) = { + def generateCircle(radius: Double, n: Int): Seq[(Double, Double)] = { Seq.tabulate(n) { i => val theta = 2.0 * math.Pi * i / n (radius * math.cos(theta), radius * math.sin(theta)) @@ -147,7 +147,7 @@ object PowerIterationClusteringExample { /** * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel */ - def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double) = { + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double): Double = { val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6131ba8832691..fc4e12773c46d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -41,7 +41,7 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { def getNumFeatures: Int = get(numFeatures) /** @group setParam */ - def setNumFeatures(value: Int) = set(numFeatures, value) + def setNumFeatures(value: Int): this.type = set(numFeatures, value) override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index cbd87ea8aeb37..15ca2547d56a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -345,9 +345,13 @@ private[python] class PythonMLLibAPI extends Serializable { def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = predict(SerDe.asTupleRDD(userAndProducts.rdd)) - def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + def getUserFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + } - def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + def getProductFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + } } @@ -909,7 +913,7 @@ private[spark] object SerDe extends Serializable { // Pickler for DenseVector private[python] class DenseVectorPickler extends BasePickler[DenseVector] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val vector: DenseVector = obj.asInstanceOf[DenseVector] val bytes = new Array[Byte](8 * vector.size) val bb = ByteBuffer.wrap(bytes) @@ -941,7 +945,7 @@ private[spark] object SerDe extends Serializable { // Pickler for DenseMatrix private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] val bytes = new Array[Byte](8 * m.values.size) val order = ByteOrder.nativeOrder() @@ -973,7 +977,7 @@ private[spark] object SerDe extends Serializable { // Pickler for SparseVector private[python] class SparseVectorPickler extends BasePickler[SparseVector] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val v: SparseVector = obj.asInstanceOf[SparseVector] val n = v.indices.size val indiceBytes = new Array[Byte](4 * n) @@ -1015,7 +1019,7 @@ private[spark] object SerDe extends Serializable { // Pickler for LabeledPoint private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val point: LabeledPoint = obj.asInstanceOf[LabeledPoint] saveObjects(out, pickler, point.label, point.features) } @@ -1031,7 +1035,7 @@ private[spark] object SerDe extends Serializable { // Pickler for Rating private[python] class RatingPickler extends BasePickler[Rating] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val rating: Rating = obj.asInstanceOf[Rating] saveObjects(out, pickler, rating.user, rating.product, rating.rating) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 2ebc7fa5d4234..068449aa1d346 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -83,10 +83,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { private object SaveLoadV1_0 { - def thisFormatVersion = "1.0" + def thisFormatVersion: String = "1.0" /** Hard-code class name string in case it changes in the future */ - def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel" + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" /** Model data for model import/export */ case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) @@ -174,7 +174,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ - def run(data: RDD[LabeledPoint]) = { + def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { case SparseVector(size, indices, values) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 8956189ff1158..3b6790cce47c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -32,7 +32,7 @@ private[classification] object GLMClassificationModel { object SaveLoadV1_0 { - def thisFormatVersion = "1.0" + def thisFormatVersion: String = "1.0" /** Model data for import/export */ case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index e41f941fd2c2c..0f8d6a399682d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -536,5 +536,5 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable def this(array: Array[Double]) = this(Vectors.dense(array)) /** Converts the vector to a dense vector. */ - def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm) + def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index ea10bde5fa252..a8378a76d20ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -96,30 +96,30 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns precision for a given label (category) * @param label the label. */ - def precision(label: Double) = { + def precision(label: Double): Double = { val tp = tpPerClass(label) val fp = fpPerClass.getOrElse(label, 0L) - if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp) } /** * Returns recall for a given label (category) * @param label the label. */ - def recall(label: Double) = { + def recall(label: Double): Double = { val tp = tpPerClass(label) val fn = fnPerClass.getOrElse(label, 0L) - if (tp + fn == 0) 0 else tp.toDouble / (tp + fn) + if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn) } /** * Returns f1-measure for a given label (category) * @param label the label. */ - def f1Measure(label: Double) = { + def f1Measure(label: Double): Double = { val p = precision(label) val r = recall(label) - if((p + r) == 0) 0 else 2 * p * r / (p + r) + if((p + r) == 0) 0.0 else 2 * p * r / (p + r) } private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp } @@ -130,7 +130,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based precision * (equals to micro-averaged document-based precision) */ - lazy val microPrecision = { + lazy val microPrecision: Double = { val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} sumTp.toDouble / (sumTp + sumFp) } @@ -139,7 +139,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based recall * (equals to micro-averaged document-based recall) */ - lazy val microRecall = { + lazy val microRecall: Double = { val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} sumTp.toDouble / (sumTp + sumFn) } @@ -148,7 +148,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based f1-measure * (equals to micro-averaged document-based f1-measure) */ - lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) + lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) /** * Returns the sequence of labels in ascending order diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 0e4a4d0085895..fdd8848189f19 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -146,12 +146,16 @@ class DenseMatrix( def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case m: DenseMatrix => m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) case _ => false } + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray) + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BDM[Double](numRows, numCols, values) @@ -173,7 +177,7 @@ class DenseMatrix( values(index(i, j)) = v } - override def copy = new DenseMatrix(numRows, numCols, values.clone()) + override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f)) @@ -431,7 +435,9 @@ class SparseMatrix( } } - override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + override def copy: SparseMatrix = { + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + } private[mllib] def map(f: Double => Double) = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index e9d25dcb7e778..2cda9b252ee06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -183,6 +183,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } } + override def hashCode: Int = 7919 + private[spark] override def asNullable: VectorUDT = this } @@ -478,7 +480,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) - override def apply(i: Int) = values(i) + override def apply(i: Int): Double = values(i) override def copy: DenseVector = { new DenseVector(values.clone()) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 1d253963130f1..3323ae7b1fba0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -49,7 +49,7 @@ private[mllib] class GridPartitioner( private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt - override val numPartitions = rowPartitions * colPartitions + override val numPartitions: Int = rowPartitions * colPartitions /** * Returns the index of the partition the input coordinate belongs to. @@ -85,6 +85,14 @@ private[mllib] class GridPartitioner( false } } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + rows: java.lang.Integer, + cols: java.lang.Integer, + rowsPerPart: java.lang.Integer, + colsPerPart: java.lang.Integer) + } } private[mllib] object GridPartitioner { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 405bae62ee8b6..9349ecaa13f56 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -56,7 +56,7 @@ class UniformGenerator extends RandomDataGenerator[Double] { random.nextDouble() } - override def setSeed(seed: Long) = random.setSeed(seed) + override def setSeed(seed: Long): Unit = random.setSeed(seed) override def copy(): UniformGenerator = new UniformGenerator() } @@ -75,7 +75,7 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { random.nextGaussian() } - override def setSeed(seed: Long) = random.setSeed(seed) + override def setSeed(seed: Long): Unit = random.setSeed(seed) override def copy(): StandardNormalGenerator = new StandardNormalGenerator() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index bd7e340ca2d8e..b55944f74f623 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -32,7 +32,7 @@ private[regression] object GLMRegressionModel { object SaveLoadV1_0 { - def thisFormatVersion = "1.0" + def thisFormatVersion: String = "1.0" /** Model data for model import/export */ case class Data(weights: Vector, intercept: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 8d5c36da32bdb..ada227c200a79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -83,10 +83,13 @@ class Strategy ( @BeanProperty var useNodeIdCache: Boolean = false, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - def isMulticlassClassification = + def isMulticlassClassification: Boolean = { algo == Classification && numClasses > 2 - def isMulticlassWithCategoricalFeatures - = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + } + + def isMulticlassWithCategoricalFeatures: Boolean = { + isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + } /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index b7950e00786ab..5ac10f3fd32dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -71,7 +71,7 @@ object Entropy extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c946db9c0d1c8..19d318203c344 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -67,7 +67,7 @@ object Gini extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index df9eafa5da16a..7104a7fa4dd4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -58,7 +58,7 @@ object Variance extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 8a57ebc387d01..c9bafd60fba4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -120,10 +120,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { private[tree] object SaveLoadV1_0 { - def thisFormatVersion = "1.0" + def thisFormatVersion: String = "1.0" // Hard-code class name string in case it changes in the future - def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel" + def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel" case class PredictData(predict: Double, prob: Double) { def toPredict: Predict = new Predict(predict, prob) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 80990aa9a603f..f209fdafd3653 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -38,23 +38,32 @@ class InformationGainStats( val leftPredict: Predict, val rightPredict: Predict) extends Serializable { - override def toString = { + override def toString: String = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" .format(gain, impurity, leftImpurity, rightImpurity) } - override def equals(o: Any) = - o match { - case other: InformationGainStats => { - gain == other.gain && - impurity == other.impurity && - leftImpurity == other.leftImpurity && - rightImpurity == other.rightImpurity && - leftPredict == other.leftPredict && - rightPredict == other.rightPredict - } - case _ => false - } + override def equals(o: Any): Boolean = o match { + case other: InformationGainStats => + gain == other.gain && + impurity == other.impurity && + leftImpurity == other.leftImpurity && + rightImpurity == other.rightImpurity && + leftPredict == other.leftPredict && + rightPredict == other.rightPredict + + case _ => false + } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + gain: java.lang.Double, + impurity: java.lang.Double, + leftImpurity: java.lang.Double, + rightImpurity: java.lang.Double, + leftPredict, + rightPredict) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index d961081d185e9..4f72bb8014cc0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -50,8 +50,10 @@ class Node ( var rightNode: Option[Node], var stats: Option[InformationGainStats]) extends Serializable with Logging { - override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + "split = " + split + ", stats = " + stats + override def toString: String = { + "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + + "impurity = " + impurity + "split = " + split + ", stats = " + stats + } /** * build the left node and right nodes if not leaf diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index ad4c0dbbfb3e5..25990af7c6cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -29,7 +29,7 @@ class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { - override def toString = { + override def toString: String = { "predict = %f, prob = %f".format(predict, prob) } @@ -39,4 +39,8 @@ class Predict( case _ => false } } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(predict: java.lang.Double, prob: java.lang.Double) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index b7a85f58544a3..fb35e70a8d077 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -38,9 +38,10 @@ case class Split( featureType: FeatureType, categories: List[Double]) { - override def toString = + override def toString: String = { "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + ", categories = " + categories + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 30a8f7ca301af..f160852c69c77 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -79,7 +79,7 @@ object RandomForestModel extends Loader[RandomForestModel] { private object SaveLoadV1_0 { // Hard-code class name string in case it changes in the future - def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel" + def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel" } } @@ -130,7 +130,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { private object SaveLoadV1_0 { // Hard-code class name string in case it changes in the future - def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + def thisClassName: String = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" } } @@ -257,7 +257,7 @@ private[tree] object TreeEnsembleModel extends Logging { import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees} - def thisFormatVersion = "1.0" + def thisFormatVersion: String = "1.0" case class Metadata( algo: String, From 28bcb9e9e86a4b643fbf96b2b7e03928ebcfc060 Mon Sep 17 00:00:00 2001 From: mbonaci Date: Fri, 20 Mar 2015 18:30:45 +0000 Subject: [PATCH 12/47] [SPARK-6370][core] Documentation: Improve all 3 docs for RDD.sample The docs for the `sample` method were insufficient, now less so. Author: mbonaci Closes #5097 from mbonaci/master and squashes the following commits: a6a9d97 [mbonaci] [SPARK-6370][core] Documentation: Improve all 3 docs for RDD.sample method --- .../scala/org/apache/spark/api/java/JavaRDD.scala | 11 +++++++++++ core/src/main/scala/org/apache/spark/rdd/RDD.scala | 6 ++++++ python/pyspark/rdd.py | 6 ++++++ 3 files changed, 23 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 645dc3bfb6b06..3e9beb670f7ad 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -101,12 +101,23 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * @param seed seed for the random number generator */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a139780d967e9..a4c74ed03e330 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -377,6 +377,12 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * @param seed seed for the random number generator */ def sample(withReplacement: Boolean, fraction: Double, diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index bf17f513c0bc3..c337a43c8a7fc 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -346,6 +346,12 @@ def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD. + :param withReplacement: can elements be sampled multiple times (replaced when sampled out) + :param fraction: expected size of the sample as a fraction of this RDD's size + without replacement: probability that each element is chosen; fraction must be [0, 1] + with replacement: expected number of times each element is chosen; fraction must be >= 0 + :param seed: seed for the random number generator + >>> rdd = sc.parallelize(range(100), 4) >>> rdd.sample(False, 0.1, 81).count() 10 From 385b2ff10d9ef5083df49233f77c8e873561dc16 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Fri, 20 Mar 2015 18:42:18 +0000 Subject: [PATCH 13/47] [SPARK-6426][Doc]User could also point the yarn cluster config directory via YARN_CONF_DI... ...R https://issues.apache.org/jira/browse/SPARK-6426 Author: WangTaoTheTonic Closes #5103 from WangTaoTheTonic/SPARK-6426 and squashes the following commits: e6dd78d [WangTaoTheTonic] User could also point the yarn cluster config directory via YARN_CONF_DIR --- docs/submitting-applications.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 57b074778f2b0..3ecbf2308cd44 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -133,10 +133,10 @@ The master URL passed to Spark can be in one of the following formats: Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... yarn-client Connect to a YARN cluster in -client mode. The cluster location will be found based on the HADOOP_CONF_DIR variable. +client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. yarn-cluster Connect to a YARN cluster in -cluster mode. The cluster location will be found based on HADOOP_CONF_DIR. +cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. From a74564591f1c824f9eed516ae79e079b355fd32b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Mar 2015 18:43:57 +0000 Subject: [PATCH 14/47] [SPARK-6371] [build] Update version to 1.4.0-SNAPSHOT. Author: Marcelo Vanzin Closes #5056 from vanzin/SPARK-6371 and squashes the following commits: 63220df [Marcelo Vanzin] Merge branch 'master' into SPARK-6371 6506f75 [Marcelo Vanzin] Use more fine-grained exclusion. 178ba71 [Marcelo Vanzin] Oops. 75b2375 [Marcelo Vanzin] Exclude VertexRDD in MiMA. a45a62c [Marcelo Vanzin] Work around MIMA warning. 1d8a670 [Marcelo Vanzin] Re-group jetty exclusion. 0e8e909 [Marcelo Vanzin] Ignore ml, don't ignore graphx. cef4603 [Marcelo Vanzin] Indentation. 296cf82 [Marcelo Vanzin] [SPARK-6371] [build] Update version to 1.4.0-SNAPSHOT. --- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- core/src/main/scala/org/apache/spark/package.scala | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-assembly/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/java8-tests/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib/pom.xml | 2 +- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- pom.xml | 2 +- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 14 ++++++++++++++ repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/pom.xml | 2 +- 33 files changed, 47 insertions(+), 33 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index d3bb4bde0c412..f1f8b0d3682e2 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 1fe61062b4606..1f3dec91314f2 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 81f8cba711df6..6cd1965ec37c2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index b6249b492150a..2ab41ba488ff6 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.3.0-SNAPSHOT" + val SPARK_VERSION = "1.4.0-SNAPSHOT" } diff --git a/docs/_config.yml b/docs/_config.yml index 0652927a8ce9b..b22b627f09007 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.3.0 +SPARK_VERSION: 1.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.4.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/examples/pom.xml b/examples/pom.xml index 994071d94d0ad..7e93f0eec0b91 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 96c2787e35cd0..67907bbfb6d1b 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 172d447b77cda..8df7edbdcad33 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 5109b8ed87524..0b79f47647f6b 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 369856187a244..f695cff410a18 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index a344f000c5002..98f95a9a64fa0 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index e95853f005ce2..8b6a8959ac4cf 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 9b3475d7c3dc2..a50d378b34335 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index bc2f8be10c9ce..4351a8a12fe21 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 7e49a71907336..25847a1b33d9c 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 6eb29af03f833..e14bbae4a9b6e 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index c0d534e185d7f..d38a3aa8256b7 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index ccbd9d0419a98..0fe2814135d88 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a76704a8c2c59..4c183543e3fa8 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/network/common/pom.xml b/network/common/pom.xml index 74437f37c47e4..7b51845206f4a 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index a2bcca26d8344..7dc7c65825e34 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index cea7a20c223e2..1e2e9c80af6cc 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index efb9f172f4751..23bb16130b504 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index f0cbf4e57b8c5..dde92949fa175 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.2.0" + val previousSparkVersion = "1.3.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a6b07fa7cddec..328d59485a731 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -16,6 +16,7 @@ */ import com.typesafe.tools.mima.core._ +import com.typesafe.tools.mima.core.ProblemFilters._ /** * Additional excludes for checking of Spark's binary compatibility. @@ -33,6 +34,19 @@ import com.typesafe.tools.mima.core._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.4") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ) + case v if v.startsWith("1.3") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/repl/pom.xml b/repl/pom.xml index 295f88ea3ecf9..edfa1c7f2c29c 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 8ad026dbdf8ff..3dea2ee76542f 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 3640104e497d4..e3a6b1fe72435 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index f466a3c0b5dc2..a96b1ffc26966 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 0e3f4eb98cbf7..a9816f6c38cd2 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 96508d83f4049..23a8358d45c2a 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 181236d1bcbf6..1c6f3e83a1819 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index c13534f0410a1..7c8c3613e7a05 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.3.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml From 48866f789712b0cdbaf76054d1014c6df032fff1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 20 Mar 2015 14:44:21 -0400 Subject: [PATCH 15/47] [SPARK-6095] [MLLIB] Support model save/load in Python's linear models For Python's linear models, weights and intercept are stored in Python. This PR implements Python's linear models sava/load functions which do the same thing as scala. It can also make model import/export cross languages. Author: Yanbo Liang Closes #5016 from yanboliang/spark-6095 and squashes the following commits: d9bb824 [Yanbo Liang] fix python style b3813ca [Yanbo Liang] linear model save/load for Python reuse the Scala implementation --- python/pyspark/mllib/classification.py | 58 +++++++++++++++++- python/pyspark/mllib/regression.py | 84 +++++++++++++++++++++++++- python/pyspark/mllib/util.py | 6 +- 3 files changed, 145 insertions(+), 3 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e4765173709e8..b66159c5bfb66 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,7 +21,7 @@ from numpy import array from pyspark import RDD -from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -99,6 +99,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): 1 >>> lrm.predict(SparseVector(2, {0: 1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LogisticRegressionModel.load(sc, path) + >>> sameModel.predict(array([0.0, 1.0])) + 1 + >>> sameModel.predict(SparseVector(2, {0: 1.0})) + 0 + >>> try: + ... os.removedirs(path) + ... except: + ... pass """ def __init__(self, weights, intercept): super(LogisticRegressionModel, self).__init__(weights, intercept) @@ -124,6 +136,22 @@ def predict(self, x): else: return 1 if prob > self._threshold else 0 + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + threshold = java_model.getThreshold().get() + model = LogisticRegressionModel(weights, intercept) + model.setThreshold(threshold) + return model + class LogisticRegressionWithSGD(object): @@ -243,6 +271,18 @@ class SVMModel(LinearBinaryClassificationModel): 1 >>> svm.predict(SparseVector(2, {0: -1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> svm.save(sc, path) + >>> sameModel = SVMModel.load(sc, path) + >>> sameModel.predict(SparseVector(2, {1: 1.0})) + 1 + >>> sameModel.predict(SparseVector(2, {0: -1.0})) + 0 + >>> try: + ... os.removedirs(path) + ... except: + ... pass """ def __init__(self, weights, intercept): super(SVMModel, self).__init__(weights, intercept) @@ -263,6 +303,22 @@ def predict(self, x): else: return 1 if margin > self._threshold else 0 + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + threshold = java_model.getThreshold().get() + model = SVMModel(weights, intercept) + model.setThreshold(threshold) + return model + class SVMWithSGD(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 0c21ad578793f..015a7860116c9 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,8 +18,9 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc, inherit_doc +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.util import Saveable, Loader __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'LinearRegressionWithSGD', @@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LinearRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LinearRegressionModel(weights, intercept) + return model # train_func should take two parameters, namely data and initial_weights, and @@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LassoModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LassoModel(weights, intercept) + return model class LassoWithSGD(object): @@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = RidgeRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = RidgeRegressionModel(weights, intercept) + return model class RidgeRegressionWithSGD(object): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index e877c720ac77a..c5c3468eb95e9 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -20,7 +20,6 @@ from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint class MLUtils(object): @@ -50,6 +49,7 @@ def _parse_libsvm_line(line, multiclass=None): @staticmethod def _convert_labeled_point_to_libsvm(p): """Converts a LabeledPoint to a string in LIBSVM format.""" + from pyspark.mllib.regression import LabeledPoint assert isinstance(p, LabeledPoint) items = [str(p.label)] v = _convert_to_vector(p.features) @@ -92,6 +92,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() @@ -110,6 +111,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None >>> print examples[2] (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) """ + from pyspark.mllib.regression import LabeledPoint if multiclass is not None: warnings.warn("deprecated", DeprecationWarning) @@ -130,6 +132,7 @@ def saveAsLibSVMFile(data, dir): >>> from tempfile import NamedTemporaryFile >>> from fileinput import input + >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ @@ -156,6 +159,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) From 5e6ad24ff645a9b0f63d9c0f17193550963aa0a7 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Fri, 20 Mar 2015 14:45:44 -0400 Subject: [PATCH 16/47] [MLlib] SPARK-5954: Top by key This PR implements two functions - `topByKey(num: Int): RDD[(K, Array[V])]` finds the top-k values for each key in a pair RDD. This can be used, e.g., in computing top recommendations. - `takeOrderedByKey(num: Int): RDD[(K, Array[V])] ` does the opposite of `topByKey` The `sorted` is used here as the `toArray` method of the PriorityQueue does not return a necessarily sorted array. Author: Shuo Xiang Closes #5075 from coderxiang/topByKey and squashes the following commits: 1611c37 [Shuo Xiang] code clean up 6f565c0 [Shuo Xiang] naming a80e0ec [Shuo Xiang] typo and warning 82dded9 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into topByKey d202745 [Shuo Xiang] move to MLPairRDDFunctions 901b0af [Shuo Xiang] style check 70c6e35 [Shuo Xiang] remove takeOrderedByKey, update doc and test 0895c17 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into topByKey b10e325 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into topByKey debccad [Shuo Xiang] topByKey --- .../spark/mllib/rdd/MLPairRDDFunctions.scala | 60 +++++++++++++++++++ .../mllib/rdd/MLPairRDDFunctionsSuite.scala | 36 +++++++++++ 2 files changed, 96 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala new file mode 100644 index 0000000000000..9213fd3f595c3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.util.BoundedPriorityQueue + +/** + * Machine learning specific Pair RDD functions. + */ +@DeveloperApi +class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Serializable { + /** + * Returns the top k (largest) elements for each key from this RDD as defined by the specified + * implicit Ordering[T]. + * If the number of elements for a certain key is less than k, all of them will be returned. + * + * @param num k, the number of top elements to return + * @param ord the implicit ordering for T + * @return an RDD that contains the top k values for each key + */ + def topByKey(num: Int)(implicit ord: Ordering[V]): RDD[(K, Array[V])] = { + self.aggregateByKey(new BoundedPriorityQueue[V](num)(ord))( + seqOp = (queue, item) => { + queue += item + queue + }, + combOp = (queue1, queue2) => { + queue1 ++= queue2 + queue1 + } + ).mapValues(_.toArray.sorted(ord.reverse)) + } +} + +@DeveloperApi +object MLPairRDDFunctions { + /** Implicit conversion from a pair RDD to MLPairRDDFunctions. */ + implicit def fromPairRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): MLPairRDDFunctions[K, V] = + new MLPairRDDFunctions[K, V](rdd) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala new file mode 100644 index 0000000000000..1ac7c12c4e8e6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ + +class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { + test("topByKey") { + val topMap = sc.parallelize(Array((1, 1), (1, 2), (3, 2), (3, 7), (3, 5), (5, 1), (5, 3)), 2) + .topByKey(2) + .collectAsMap() + + assert(topMap.size === 3) + assert(topMap(1) === Array(2, 1)) + assert(topMap(3) === Array(7, 5)) + assert(topMap(5) === Array(3, 1)) + } +} From 25636d9867c6bc901463b6b227cb444d701cfdd1 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 20 Mar 2015 14:53:59 -0400 Subject: [PATCH 17/47] [Spark 6096][MLlib] Add Naive Bayes load save methods in Python See [SPARK-6096](https://issues.apache.org/jira/browse/SPARK-6096). Author: Xusen Yin Closes #5090 from yinxusen/SPARK-6096 and squashes the following commits: bd0fea5 [Xusen Yin] fix style problem, etc. 3fd41f2 [Xusen Yin] use hanging indent in Python style e83803d [Xusen Yin] fix Python style d6dbde5 [Xusen Yin] fix python call java error a054bb3 [Xusen Yin] add save load for NaiveBayes python --- .../mllib/classification/NaiveBayes.scala | 11 +++++++ python/pyspark/mllib/classification.py | 31 ++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 068449aa1d346..d60e82c410979 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,6 +17,10 @@ package org.apache.spark.mllib.classification +import java.lang.{Iterable => JIterable} + +import scala.collection.JavaConverters._ + import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -41,6 +45,13 @@ class NaiveBayesModel private[mllib] ( val pi: Array[Double], val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { + /** A Java-friendly constructor that takes three Iterable parameters. */ + private[mllib] def this( + labels: JIterable[Double], + pi: JIterable[Double], + theta: JIterable[JIterable[Double]]) = + this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) + private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM[Double](theta.length, theta(0).length) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index b66159c5bfb66..6766f3ebb8894 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -24,6 +24,7 @@ from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper +from pyspark.mllib.util import Saveable, Loader, inherit_doc __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', @@ -359,7 +360,8 @@ def train(rdd, i): return _regression_train_wrapper(train, SVMModel, data, initialWeights) -class NaiveBayesModel(object): +@inherit_doc +class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. @@ -390,6 +392,16 @@ class NaiveBayesModel(object): 0.0 >>> model.predict(SparseVector(2, {0: 1.0})) 1.0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = NaiveBayesModel.load(sc, path) + >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) + True + >>> try: + ... os.removedirs(path) + ... except OSError: + ... pass """ def __init__(self, labels, pi, theta): @@ -404,6 +416,23 @@ def predict(self, x): x = _convert_to_vector(x) return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] + def save(self, sc, path): + java_labels = _py2java(sc, self.labels.tolist()) + java_pi = _py2java(sc, self.pi.tolist()) + java_theta = _py2java(sc, self.theta.tolist()) + java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel( + java_labels, java_pi, java_theta) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load( + sc._jsc.sc(), path) + py_labels = _java2py(sc, java_model.labels()) + py_pi = _java2py(sc, java_model.pi()) + py_theta = _java2py(sc, java_model.theta()) + return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta)) + class NaiveBayes(object): From 6b36470c66bd6140c45e45d3f1d51b0082c3fd97 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 20 Mar 2015 15:02:57 -0400 Subject: [PATCH 18/47] [SPARK-5955][MLLIB] add checkpointInterval to ALS Add checkpiontInterval to ALS to prevent: 1. StackOverflow exceptions caused by long lineage, 2. large shuffle files generated during iterations, 3. slow recovery when some node fail. srowen coderxiang Author: Xiangrui Meng Closes #5076 from mengxr/SPARK-5955 and squashes the following commits: df56791 [Xiangrui Meng] update impl to reuse code 29affcb [Xiangrui Meng] do not materialize factors in implicit 20d3f7f [Xiangrui Meng] add checkpointInterval to ALS --- .../apache/spark/ml/param/sharedParams.scala | 11 +++++ .../apache/spark/ml/recommendation/ALS.scala | 42 ++++++++++++++++--- .../spark/mllib/recommendation/ALS.scala | 17 ++++++++ .../spark/ml/recommendation/ALSSuite.scala | 17 ++++++++ 4 files changed, 82 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index 1a70322b4cace..5d660d1e151a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -138,3 +138,14 @@ private[ml] trait HasOutputCol extends Params { /** @group getParam */ def getOutputCol: String = get(outputCol) } + +private[ml] trait HasCheckpointInterval extends Params { + /** + * param for checkpoint interval + * @group param + */ + val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + def getCheckpointInterval: Int = get(checkpointInterval) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index e3515ee81af3d..514b4ef98dc5b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.recommendation import java.{util => ju} +import java.io.IOException import scala.collection.mutable import scala.reflect.ClassTag @@ -26,6 +27,7 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.apache.hadoop.fs.{FileSystem, Path} import org.netlib.util.intW import org.apache.spark.{Logging, Partitioner} @@ -46,7 +48,7 @@ import org.apache.spark.util.random.XORShiftRandom * Common params for ALS. */ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam - with HasPredictionCol { + with HasPredictionCol with HasCheckpointInterval { /** * Param for rank of the matrix factorization. @@ -164,6 +166,7 @@ class ALSModel private[ml] ( itemFactors: RDD[(Int, Array[Float])]) extends Model[ALSModel] with ALSParams { + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { @@ -262,6 +265,9 @@ class ALS extends Estimator[ALSModel] with ALSParams { /** @group setParam */ def setNonnegative(value: Boolean): this.type = set(nonnegative, value) + /** @group setParam */ + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + /** * Sets both numUserBlocks and numItemBlocks to the specific value. * @group setParam @@ -274,6 +280,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setMaxIter(20) setRegParam(1.0) + setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { val map = this.paramMap ++ paramMap @@ -285,7 +292,8 @@ class ALS extends Estimator[ALSModel] with ALSParams { val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), - alpha = map(alpha), nonnegative = map(nonnegative)) + alpha = map(alpha), nonnegative = map(nonnegative), + checkpointInterval = map(checkpointInterval)) val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) Params.inheritValues(map, this, model) model @@ -494,6 +502,7 @@ object ALS extends Logging { nonnegative: Boolean = false, intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + checkpointInterval: Int = 10, seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { require(intermediateRDDStorageLevel != StorageLevel.NONE, @@ -521,6 +530,18 @@ object ALS extends Logging { val seedGen = new XORShiftRandom(seed) var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + var previousCheckpointFile: Option[String] = None + val shouldCheckpoint: Int => Boolean = (iter) => + sc.checkpointDir.isDefined && (iter % checkpointInterval == 0) + val deletePreviousCheckpointFile: () => Unit = () => + previousCheckpointFile.foreach { file => + try { + FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true) + } catch { + case e: IOException => + logWarning(s"Cannot delete checkpoint file $file:", e) + } + } if (implicitPrefs) { for (iter <- 1 to maxIter) { userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) @@ -528,19 +549,30 @@ object ALS extends Logging { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, implicitPrefs, alpha, solver) previousItemFactors.unpersist() - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - itemFactors.checkpoint() - } itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) + // TODO: Generalize PeriodicGraphCheckpointer and use it here. + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() // itemFactors gets materialized in computeFactors. + } val previousUserFactors = userFactors userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, implicitPrefs, alpha, solver) + if (shouldCheckpoint(iter)) { + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } previousUserFactors.unpersist() } } else { for (iter <- 0 until maxIter) { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, solver = solver) + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() + itemFactors.count() // checkpoint item factors and cut lineage + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, solver = solver) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index caacab943030b..dddefe1944e9d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -82,6 +82,9 @@ class ALS private ( private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** checkpoint interval */ + private var checkpointInterval: Int = 10 + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. @@ -182,6 +185,19 @@ class ALS private ( this } + /** + * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with + * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps + * with eliminating temporary shuffle files on disk, which can be important when there are many + * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]], + * this setting is ignored. + */ + @DeveloperApi + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -212,6 +228,7 @@ class ALS private ( nonnegative = nonnegative, intermediateRDDStorageLevel = intermediateRDDStorageLevel, finalRDDStorageLevel = StorageLevel.NONE, + checkpointInterval = checkpointInterval, seed = seed) val userFactors = floatUserFactors diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index bb86bafc0eb0a..0bb06e9e8ac9c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.recommendation +import java.io.File import java.util.Random import scala.collection.mutable @@ -32,16 +33,25 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.util.Utils class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { private var sqlContext: SQLContext = _ + private var tempDir: File = _ override def beforeAll(): Unit = { super.beforeAll() + tempDir = Utils.createTempDir() + sc.setCheckpointDir(tempDir.getAbsolutePath) sqlContext = new SQLContext(sc) } + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + test("LocalIndexEncoder") { val random = new Random for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { @@ -485,4 +495,11 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { }.count() } } + + test("als with large number of iterations") { + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2) + ALS.train( + ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true) + } } From 49a01c7ea2c48feee7ab4551c4fa03fd1cdb1a32 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 20 Mar 2015 19:14:35 +0000 Subject: [PATCH 19/47] [SPARK-6423][Mesos] MemoryUtils should use memoryOverhead if it's set - Fixed calculateTotalMemory to use spark.mesos.executor.memoryOverhead - Added testCase Author: Jongyoul Lee Closes #5099 from jongyoul/SPARK-6423 and squashes the following commits: 6747fce [Jongyoul Lee] [SPARK-6423][Mesos] MemoryUtils should use memoryOverhead if it's set - Changed a description of spark.mesos.executor.memoryOverhead 475a7c8 [Jongyoul Lee] [SPARK-6423][Mesos] MemoryUtils should use memoryOverhead if it's set - Fit the import rules 453c5a2 [Jongyoul Lee] [SPARK-6423][Mesos] MemoryUtils should use memoryOverhead if it's set - Fixed calculateTotalMemory to use spark.mesos.executor.memoryOverhead - Added testCase --- .../scheduler/cluster/mesos/MemoryUtils.scala | 10 ++-- .../cluster/mesos/MemoryUtilsSuite.scala | 47 +++++++++++++++++++ docs/running-on-mesos.md | 8 ++-- 3 files changed, 53 insertions(+), 12 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala index 705116cb13f54..aa3ec0f8cfb9c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala @@ -21,15 +21,11 @@ import org.apache.spark.SparkContext private[spark] object MemoryUtils { // These defaults copied from YARN - val OVERHEAD_FRACTION = 1.10 + val OVERHEAD_FRACTION = 0.10 val OVERHEAD_MINIMUM = 384 def calculateTotalMemory(sc: SparkContext) = { - math.max( - sc.conf.getOption("spark.mesos.executor.memoryOverhead") - .getOrElse(OVERHEAD_MINIMUM.toString) - .toInt + sc.executorMemory, - OVERHEAD_FRACTION * sc.executorMemory - ) + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala new file mode 100644 index 0000000000000..3fa0115e68259 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.mockito.Mockito._ +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{SparkConf, SparkContext} + +class MemoryUtilsSuite extends FunSuite with MockitoSugar { + test("MesosMemoryUtils should always override memoryOverhead when it's set") { + val sparkConf = new SparkConf + + val sc = mock[SparkContext] + when(sc.conf).thenReturn(sparkConf) + + // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 + when(sc.executorMemory).thenReturn(512) + assert(MemoryUtils.calculateTotalMemory(sc) === 896) + + // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 + when(sc.executorMemory).thenReturn(4096) + assert(MemoryUtils.calculateTotalMemory(sc) === 4505) + + // set memoryOverhead + sparkConf.set("spark.mesos.executor.memoryOverhead", "100") + assert(MemoryUtils.calculateTotalMemory(sc) === 4196) + sparkConf.set("spark.mesos.executor.memoryOverhead", "400") + assert(MemoryUtils.calculateTotalMemory(sc) === 4496) + } +} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 6a9d304501dc0..c984639bd34cf 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -224,11 +224,9 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.executor.memoryOverhead executor memory * 0.10, with minimum of 384 - This value is an additive for spark.executor.memory, specified in MB, - which is used to calculate the total Mesos task memory. A value of 384 - implies a 384MB overhead. Additionally, there is a hard-coded 10% minimum - overhead. The final overhead will be the larger of either - `spark.mesos.executor.memoryOverhead` or 10% of `spark.executor.memory`. + The amount of additional memory, specified in MB, to be allocated per executor. By default, + the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the final overhead will be this value. From 11e025956be3818c00effef0d650734f8feeb436 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 20 Mar 2015 17:13:18 -0400 Subject: [PATCH 20/47] [SPARK-6309] [SQL] [MLlib] Implement MatrixUDT Utilities to serialize and deserialize Matrices in MLlib Author: MechCoder Closes #5048 from MechCoder/spark-6309 and squashes the following commits: 05dc6f2 [MechCoder] Hashcode and organize imports 16d5d47 [MechCoder] Test some more 6e67020 [MechCoder] TST: Test using Array conversion instead of equals 7fa7a2c [MechCoder] [SPARK-6309] [SQL] [MLlib] Implement MatrixUDT --- .../apache/spark/mllib/linalg/Matrices.scala | 90 +++++++++++++++++++ .../spark/mllib/linalg/MatricesSuite.scala | 13 +++ 2 files changed, 103 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index fdd8848189f19..849f44295f089 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + /** * Trait for a local matrix. */ +@SQLUserDefinedType(udt = classOf[MatrixUDT]) sealed trait Matrix extends Serializable { /** Number of rows. */ @@ -102,6 +108,88 @@ sealed trait Matrix extends Serializable { private[spark] def foreachActive(f: (Int, Int, Double) => Unit) } +@DeveloperApi +private[spark] class MatrixUDT extends UserDefinedType[Matrix] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are + // set as not nullable, except values since in the future, support for binary matrices might + // be added for which values are not needed. + // the sparse matrix needs colPtrs and rowIndices, which are set as + // null, while building the dense matrix. + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("numRows", IntegerType, nullable = false), + StructField("numCols", IntegerType, nullable = false), + StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true), + StructField("isTransposed", BooleanType, nullable = false) + )) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(7) + obj match { + case sm: SparseMatrix => + row.setByte(0, 0) + row.setInt(1, sm.numRows) + row.setInt(2, sm.numCols) + row.update(3, sm.colPtrs.toSeq) + row.update(4, sm.rowIndices.toSeq) + row.update(5, sm.values.toSeq) + row.setBoolean(6, sm.isTransposed) + + case dm: DenseMatrix => + row.setByte(0, 1) + row.setInt(1, dm.numRows) + row.setInt(2, dm.numCols) + row.setNullAt(3) + row.setNullAt(4) + row.update(5, dm.values.toSeq) + row.setBoolean(6, dm.isTransposed) + } + row + } + + override def deserialize(datum: Any): Matrix = { + datum match { + // TODO: something wrong with UDT serialization, should never happen. + case m: Matrix => m + case row: Row => + require(row.length == 7, + s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") + val tpe = row.getByte(0) + val numRows = row.getInt(1) + val numCols = row.getInt(2) + val values = row.getAs[Iterable[Double]](5).toArray + val isTransposed = row.getBoolean(6) + tpe match { + case 0 => + val colPtrs = row.getAs[Iterable[Int]](3).toArray + val rowIndices = row.getAs[Iterable[Int]](4).toArray + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) + case 1 => + new DenseMatrix(numRows, numCols, values, isTransposed) + } + } + } + + override def userClass: Class[Matrix] = classOf[Matrix] + + override def equals(o: Any): Boolean = { + o match { + case v: MatrixUDT => true + case _ => false + } + } + + override def hashCode(): Int = 1994 + + private[spark] override def asNullable: MatrixUDT = this +} + /** * Column-major dense matrix. * The entry values are stored in a single array of doubles with columns listed in sequence. @@ -119,6 +207,7 @@ sealed trait Matrix extends Serializable { * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in * row major. */ +@SQLUserDefinedType(udt = classOf[MatrixUDT]) class DenseMatrix( val numRows: Int, val numCols: Int, @@ -360,6 +449,7 @@ object DenseMatrix { * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ +@SQLUserDefinedType(udt = classOf[MatrixUDT]) class SparseMatrix( val numRows: Int, val numCols: Int, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index c098b5458fe6b..96f677db3f377 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -424,4 +424,17 @@ class MatricesSuite extends FunSuite { assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1)) assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) } + + test("MatrixUDT") { + val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8)) + val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0)) + val dm3 = new DenseMatrix(0, 0, Array()) + val sm1 = dm1.toSparse + val sm2 = dm2.toSparse + val sm3 = dm3.toSparse + val mUDT = new MatrixUDT() + Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach { + mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray) + } + } } From 257cde7c363efb3317bfb5c13975cca9154894e2 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Fri, 20 Mar 2015 17:18:18 -0400 Subject: [PATCH 21/47] [SPARK-6421][MLLIB] _regression_train_wrapper does not test initialWeights correctly Weight parameters must be initialized correctly even when numpy array is passed as initial weights. Author: lewuathe Closes #5101 from Lewuathe/SPARK-6421 and squashes the following commits: 7795201 [lewuathe] Fix lint-python errors 21d4fe3 [lewuathe] Fix init logic of weights --- python/pyspark/mllib/regression.py | 3 ++- python/pyspark/mllib/tests.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 015a7860116c9..414a0ada80787 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -163,7 +163,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) - initial_weights = initial_weights or [0.0] * len(data.first().features) + if initial_weights is None: + initial_weights = [0.0] * len(data.first().features) weights, intercept = train_func(data, _convert_to_vector(initial_weights)) return modelClass(weights, intercept) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5328d99b69684..155019638f806 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -323,6 +323,13 @@ def test_regression(self): self.assertTrue(gbt_model.predict(features[2]) <= 0) self.assertTrue(gbt_model.predict(features[3]) > 0) + try: + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + except ValueError: + self.fail() + class StatTests(PySparkTestCase): # SPARK-4023 From a95043b1780bfde556db2dcc01511e40a12498dd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 20 Mar 2015 15:47:07 -0700 Subject: [PATCH 22/47] [SPARK-6428][SQL] Added explicit type for all public methods in sql/core Also implemented equals/hashCode when they are missing. This is done in order to enable automatic public method type checking. Author: Reynold Xin Closes #5104 from rxin/sql-hashcode-explicittype and squashes the following commits: ffce6f3 [Reynold Xin] Code review feedback. 8b36733 [Reynold Xin] [SPARK-6428][SQL] Added explicit type for all public methods. --- .../catalyst/expressions/AttributeSet.scala | 3 +- .../spark/sql/catalyst/expressions/rows.scala | 21 +++++ .../org/apache/spark/sql/types/Decimal.scala | 2 +- .../apache/spark/sql/types/dataTypes.scala | 20 ++--- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../org/apache/spark/sql/DataFrame.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 8 +- .../apache/spark/sql/UDFRegistration.scala | 2 +- .../spark/sql/columnar/ColumnAccessor.scala | 4 +- .../spark/sql/columnar/ColumnBuilder.scala | 4 +- .../spark/sql/columnar/ColumnStats.scala | 24 +++--- .../spark/sql/columnar/ColumnType.scala | 56 +++++++------ .../columnar/InMemoryColumnarTableScan.scala | 56 +++++++------ .../sql/columnar/NullableColumnAccessor.scala | 2 +- .../CompressibleColumnAccessor.scala | 4 +- .../CompressibleColumnBuilder.scala | 2 +- .../compression/compressionSchemes.scala | 80 ++++++++++--------- .../spark/sql/execution/Aggregate.scala | 8 +- .../apache/spark/sql/execution/Exchange.scala | 19 ++--- .../spark/sql/execution/ExistingRDD.scala | 20 ++--- .../apache/spark/sql/execution/Expand.scala | 5 +- .../apache/spark/sql/execution/Generate.scala | 3 +- .../sql/execution/GeneratedAggregate.scala | 11 +-- .../spark/sql/execution/LocalTableScan.scala | 7 +- .../spark/sql/execution/SparkPlan.scala | 1 + .../sql/execution/SparkSqlSerializer.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../spark/sql/execution/basicOperators.scala | 73 +++++++++-------- .../apache/spark/sql/execution/commands.scala | 33 ++++---- .../spark/sql/execution/debug/package.scala | 30 +++---- .../execution/joins/BroadcastHashJoin.scala | 10 ++- .../joins/BroadcastLeftSemiJoinHash.scala | 10 +-- .../joins/BroadcastNestedLoopJoin.scala | 5 +- .../execution/joins/CartesianProduct.scala | 8 +- .../spark/sql/execution/joins/HashJoin.scala | 4 +- .../sql/execution/joins/HashOuterJoin.scala | 53 ++++++------ .../sql/execution/joins/HashedRelation.scala | 4 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 10 ++- .../execution/joins/LeftSemiJoinHash.scala | 11 +-- .../execution/joins/ShuffledHashJoin.scala | 6 +- .../spark/sql/execution/pythonUdfs.scala | 21 +++-- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 3 +- .../apache/spark/sql/jdbc/JDBCRelation.scala | 8 +- .../apache/spark/sql/json/JSONRelation.scala | 6 +- .../spark/sql/parquet/ParquetConverter.scala | 2 +- .../spark/sql/parquet/ParquetRelation.scala | 10 ++- .../sql/parquet/ParquetTableOperations.scala | 12 +-- .../apache/spark/sql/parquet/newParquet.scala | 37 ++++++--- .../sql/parquet/timestamp/NanoTime.scala | 6 +- .../spark/sql/sources/LogicalRelation.scala | 16 ++-- .../apache/spark/sql/sources/commands.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 6 +- .../org/apache/spark/sql/sources/rules.scala | 4 +- 53 files changed, 438 insertions(+), 330 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index a9ba0be596349..adaeab0b5c027 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.Star protected class AttributeEquals(val a: Attribute) { override def hashCode() = a match { @@ -115,7 +114,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) // sorts of things in its closure. override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq - override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}" + override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}" override def isEmpty: Boolean = baseSet.isEmpty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index faa366771824b..f03d6f71a9fae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -146,6 +146,27 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { result } + override def equals(o: Any): Boolean = o match { + case other: Row => + if (values.length != other.length) { + return false + } + + var i = 0 + while (i < values.length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (apply(i) != other.apply(i)) { + return false + } + i += 1 + } + true + + case _ => false + } + def copy() = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 21cc6cea4bf54..994c5202c15dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -246,7 +246,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case d: Decimal => compare(d) == 0 case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index bf39603d13bd5..d973144de3468 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -307,7 +307,7 @@ protected[sql] object NativeType { protected[sql] trait PrimitiveType extends DataType { - override def isPrimitive = true + override def isPrimitive: Boolean = true } @@ -442,7 +442,7 @@ class TimestampType private() extends NativeType { @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = new Ordering[JvmType] { - def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) } /** @@ -542,7 +542,7 @@ class LongType private() extends IntegralType { */ override def defaultSize: Int = 8 - override def simpleString = "bigint" + override def simpleString: String = "bigint" private[spark] override def asNullable: LongType = this } @@ -572,7 +572,7 @@ class IntegerType private() extends IntegralType { */ override def defaultSize: Int = 4 - override def simpleString = "int" + override def simpleString: String = "int" private[spark] override def asNullable: IntegerType = this } @@ -602,7 +602,7 @@ class ShortType private() extends IntegralType { */ override def defaultSize: Int = 2 - override def simpleString = "smallint" + override def simpleString: String = "smallint" private[spark] override def asNullable: ShortType = this } @@ -632,7 +632,7 @@ class ByteType private() extends IntegralType { */ override def defaultSize: Int = 1 - override def simpleString = "tinyint" + override def simpleString: String = "tinyint" private[spark] override def asNullable: ByteType = this } @@ -696,7 +696,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT */ override def defaultSize: Int = 4096 - override def simpleString = precisionInfo match { + override def simpleString: String = precisionInfo match { case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" case None => "decimal(10,0)" } @@ -836,7 +836,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT */ override def defaultSize: Int = 100 * elementType.defaultSize - override def simpleString = s"array<${elementType.simpleString}>" + override def simpleString: String = s"array<${elementType.simpleString}>" private[spark] override def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) @@ -1065,7 +1065,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum - override def simpleString = { + override def simpleString: String = { val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") s"struct<${fieldTypes.mkString(",")}>" } @@ -1142,7 +1142,7 @@ case class MapType( */ override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) - override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" private[spark] override def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 908c78a4d3f10..b7a13a1b26802 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -59,7 +59,7 @@ class Column(protected[sql] val expr: Expression) { override def toString: String = expr.prettyString - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 46f50708a9184..8b8f86c4127e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -722,7 +722,7 @@ class DataFrame private[sql]( : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil - def rowFunction(row: Row) = { + def rowFunction(row: Row): TraversableOnce[Row] = { f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) } val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) @@ -1155,7 +1155,7 @@ class DataFrame private[sql]( val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) new Iterator[String] { - override def hasNext = iter.hasNext + override def hasNext: Boolean = iter.hasNext override def next(): String = { JsonRDD.rowToJSON(rowSchema, gen)(iter.next()) gen.flush() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6de46a50db20e..dc9912b52dcab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -144,7 +144,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val tlSession = new ThreadLocal[SQLSession]() { - override def initialValue = defaultSession + override def initialValue: SQLSession = defaultSession } @transient @@ -988,9 +988,9 @@ class SQLContext(@transient val sparkContext: SparkContext) val sqlContext: SQLContext = self - def codegenEnabled = self.conf.codegenEnabled + def codegenEnabled: Boolean = self.conf.codegenEnabled - def numPartitions = self.conf.numShufflePartitions + def numPartitions: Int = self.conf.numShufflePartitions def strategies: Seq[Strategy] = experimental.extraStrategies ++ ( @@ -1109,7 +1109,7 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val analyzed: LogicalPlan = analyzer(logical) lazy val withCachedData: LogicalPlan = { - assertAnalyzed + assertAnalyzed() cacheManager.useCachedData(analyzed) } lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 8051df299252c..b97aaf73529a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -61,7 +61,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val dataType = sqlContext.parseDataType(stringDataType) - def builder(e: Seq[Expression]) = + def builder(e: Seq[Expression]): PythonUDF = PythonUDF( name, command, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index b615eaa0dca0d..f615fb33a7c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -48,9 +48,9 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( protected def initialize() {} - def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining - def extractTo(row: MutableRow, ordinal: Int): Unit = { + override def extractTo(row: MutableRow, ordinal: Int): Unit = { extractSingle(row, ordinal) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index d8d24a577347c..c881747751520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -58,7 +58,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( override def initialize( initialSize: Int, columnName: String = "", - useCompression: Boolean = false) = { + useCompression: Boolean = false): Unit = { val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize this.columnName = columnName @@ -73,7 +73,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( columnType.append(row, ordinal, buffer) } - override def build() = { + override def build(): ByteBuffer = { buffer.flip().asInstanceOf[ByteBuffer] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 04047b9c062be..87a6631da8300 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -76,7 +76,7 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) - def collectedStatistics = Row(null, null, nullCount, count, 0L) + override def collectedStatistics: Row = Row(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -93,7 +93,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -110,7 +110,7 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { @@ -127,7 +127,7 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { @@ -144,7 +144,7 @@ private[sql] class LongColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -161,7 +161,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { @@ -178,7 +178,7 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FixedDecimalColumnStats extends ColumnStats { @@ -212,7 +212,7 @@ private[sql] class IntColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { @@ -229,7 +229,7 @@ private[sql] class StringColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DateColumnStats extends IntColumnStats @@ -248,7 +248,7 @@ private[sql] class TimestampColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -259,7 +259,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats extends ColumnStats { @@ -270,5 +270,5 @@ private[sql] class GenericColumnStats extends ColumnStats { } } - def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 36ea1c77e0470..c47497e0662d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -98,7 +98,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def clone(v: JvmType): JvmType = v - override def toString = getClass.getSimpleName.stripSuffix("$") + override def toString: String = getClass.getSimpleName.stripSuffix("$") } private[sql] abstract class NativeColumnType[T <: NativeType]( @@ -114,7 +114,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - def append(v: Int, buffer: ByteBuffer): Unit = { + override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -122,7 +122,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { buffer.putInt(row.getInt(ordinal)) } - def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Int = { buffer.getInt() } @@ -134,7 +134,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { row.setInt(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) + override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setInt(toOrdinal, from.getInt(fromOrdinal)) @@ -150,7 +150,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { buffer.putLong(row.getLong(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Long = { buffer.getLong() } @@ -162,7 +162,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { row.setLong(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) + override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setLong(toOrdinal, from.getLong(fromOrdinal)) @@ -178,7 +178,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { buffer.putFloat(row.getFloat(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Float = { buffer.getFloat() } @@ -190,7 +190,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { row.setFloat(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) + override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) @@ -206,7 +206,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { buffer.putDouble(row.getDouble(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Double = { buffer.getDouble() } @@ -218,7 +218,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { row.setDouble(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) + override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) @@ -234,7 +234,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } - override def extract(buffer: ByteBuffer) = buffer.get() == 1 + override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { row.setBoolean(ordinal, buffer.get() == 1) @@ -244,7 +244,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { row.setBoolean(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) + override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) @@ -260,7 +260,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { buffer.put(row.getByte(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Byte = { buffer.get() } @@ -272,7 +272,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { row.setByte(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) + override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setByte(toOrdinal, from.getByte(fromOrdinal)) @@ -288,7 +288,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { buffer.putShort(row.getShort(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Short = { buffer.getShort() } @@ -300,7 +300,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { row.setShort(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) + override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setShort(toOrdinal, from.getShort(fromOrdinal)) @@ -317,7 +317,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) @@ -328,7 +328,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.setString(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): String = row.getString(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setString(toOrdinal, from.getString(fromOrdinal)) @@ -336,7 +336,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Int = { buffer.getInt } @@ -344,7 +344,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { buffer.putInt(v) } - override def getField(row: Row, ordinal: Int) = { + override def getField(row: Row, ordinal: Int): Int = { row(ordinal).asInstanceOf[Int] } @@ -354,7 +354,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { } private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Timestamp = { val timestamp = new Timestamp(buffer.getLong()) timestamp.setNanos(buffer.getInt()) timestamp @@ -364,7 +364,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { buffer.putLong(v.getTime).putInt(v.getNanos) } - override def getField(row: Row, ordinal: Int) = { + override def getField(row: Row, ordinal: Int): Timestamp = { row(ordinal).asInstanceOf[Timestamp] } @@ -405,7 +405,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(row: Row, ordinal: Int) = { + override def actualSize(row: Row, ordinal: Int): Int = { getField(row, ordinal).length + 4 } @@ -413,7 +413,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( buffer.putInt(v.length).put(v, 0, v.length) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Array[Byte] = { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) @@ -426,7 +426,9 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) row(ordinal) = value } - override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]] + override def getField(row: Row, ordinal: Int): Array[Byte] = { + row(ordinal).asInstanceOf[Array[Byte]] + } } // Used to process generic objects (all types other than those listed above). Objects should be @@ -437,7 +439,9 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } - override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal)) + override def getField(row: Row, ordinal: Int): Array[Byte] = { + SparkSqlSerializer.serialize(row(ordinal)) + } } private[sql] object ColumnType { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 387faee12b3cd..6eee0c86d6a1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import org.apache.spark.Accumulator +import org.apache.spark.sql.catalyst.expressions + import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD @@ -77,20 +80,23 @@ private[sql] case class InMemoryRelation( _statistics } - override def statistics = if (_statistics == null) { - if (batchStats.value.isEmpty) { - // Underlying columnar RDD hasn't been materialized, no useful statistics information - // available, return the default statistics. - Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + override def statistics: Statistics = { + if (_statistics == null) { + if (batchStats.value.isEmpty) { + // Underlying columnar RDD hasn't been materialized, no useful statistics information + // available, return the default statistics. + Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + } else { + // Underlying columnar RDD has been materialized, required information has also been + // collected via the `batchStats` accumulator, compute the final statistics, + // and update `_statistics`. + _statistics = Statistics(sizeInBytes = computeSizeInBytes) + _statistics + } } else { - // Underlying columnar RDD has been materialized, required information has also been collected - // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`. - _statistics = Statistics(sizeInBytes = computeSizeInBytes) + // Pre-computed statistics _statistics } - } else { - // Pre-computed statistics - _statistics } // If the cached column buffers were not passed in, we calculate them in the constructor. @@ -99,7 +105,7 @@ private[sql] case class InMemoryRelation( buildBuffers() } - def recache() = { + def recache(): Unit = { _cachedColumnBuffers.unpersist() _cachedColumnBuffers = null buildBuffers() @@ -109,7 +115,7 @@ private[sql] case class InMemoryRelation( val output = child.output val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { - def next() = { + def next(): CachedBatch = { val columnBuilders = output.map { attribute => val columnType = ColumnType(attribute.dataType) val initialBufferSize = columnType.defaultSize * batchSize @@ -144,7 +150,7 @@ private[sql] case class InMemoryRelation( CachedBatch(columnBuilders.map(_.build().array()), stats) } - def hasNext = rowIterator.hasNext + def hasNext: Boolean = rowIterator.hasNext } }.persist(storageLevel) @@ -158,9 +164,9 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated) } - override def children = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance() = { + override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), useCompression, @@ -172,7 +178,7 @@ private[sql] case class InMemoryRelation( statisticsToBePropagated).asInstanceOf[this.type] } - def cachedColumnBuffers = _cachedColumnBuffers + def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = Seq(_cachedColumnBuffers, statisticsToBePropagated) @@ -220,7 +226,7 @@ private[sql] case class InMemoryColumnarTableScan( case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } - val partitionFilters = { + val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = @@ -239,12 +245,12 @@ private[sql] case class InMemoryColumnarTableScan( } // Accumulators used for testing purposes - val readPartitions = sparkContext.accumulator(0) - val readBatches = sparkContext.accumulator(0) + val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) + val readBatches: Accumulator[Int] = sparkContext.accumulator(0) private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - override def execute() = { + override def execute(): RDD[Row] = { readPartitions.setValue(0) readBatches.setValue(0) @@ -271,7 +277,7 @@ private[sql] case class InMemoryColumnarTableScan( val nextRow = new SpecificMutableRow(requestedColumnDataTypes) - def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[Row] = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => @@ -283,7 +289,7 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[Row] { private[this] val rowLen = nextRow.length - override def next() = { + override def next(): Row = { var i = 0 while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) @@ -292,7 +298,7 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors(0).hasNext + override def hasNext: Boolean = columnAccessors(0).hasNext } } @@ -308,7 +314,7 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString = relation.partitionStatistics.schema + def statsString: String = relation.partitionStatistics.schema .zip(cachedBatch.stats.toSeq) .map { case (a, s) => s"${a.name}: $s" } .mkString(", ") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index 965782a40031b..4d35650d4b1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -55,5 +55,5 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { pos += 1 } - abstract override def hasNext = seenNulls < nullCount || super.hasNext + abstract override def hasNext: Boolean = seenNulls < nullCount || super.hasNext } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index 7dff9deac8dc0..d0b602a834dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -26,12 +26,12 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc private var decoder: Decoder[T] = _ - abstract override protected def initialize() = { + abstract override protected def initialize(): Unit = { super.initialize() decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType) } - abstract override def hasNext = super.hasNext || decoder.hasNext + abstract override def hasNext: Boolean = super.hasNext || decoder.hasNext override def extractSingle(row: MutableRow, ordinal: Int): Unit = { decoder.next(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index aead768ecdf0a..b9cfc5df550d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -81,7 +81,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] } } - override def build() = { + override def build(): ByteBuffer = { val nonNullBuffer = buildNonNulls() val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 68a5b1de7691b..8727d71c48bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -33,22 +33,23 @@ import org.apache.spark.util.Utils private[sql] case object PassThrough extends CompressionScheme { override val typeId = 0 - override def supports(columnType: ColumnType[_, _]) = true + override def supports(columnType: ColumnType[_, _]): Boolean = true - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType]( + buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { - override def uncompressedSize = 0 + override def uncompressedSize: Int = 0 - override def compressedSize = 0 + override def compressedSize: Int = 0 - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { // Writes compression type ID and copies raw contents to.putInt(PassThrough.typeId).put(from).rewind() to @@ -62,22 +63,23 @@ private[sql] case object PassThrough extends CompressionScheme { columnType.extract(buffer, row, ordinal) } - override def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining } } private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType]( + buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } - override def supports(columnType: ColumnType[_, _]) = columnType match { + override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true case _ => false } @@ -90,9 +92,9 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) private var lastRun = 0 - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = _compressedSize + override def compressedSize: Int = _compressedSize override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { val value = columnType.getField(row, ordinal) @@ -114,7 +116,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { @@ -169,7 +171,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { columnType.setField(row, ordinal, currentValue) } - override def hasNext = valueCount < run || buffer.hasRemaining + override def hasNext: Boolean = valueCount < run || buffer.hasRemaining } } @@ -179,15 +181,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // 32K unique values allowed val MAX_DICT_SIZE = Short.MaxValue - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : Decoder[T] = { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def supports(columnType: ColumnType[_, _]) = columnType match { + override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { case INT | LONG | STRING => true case _ => false } @@ -237,7 +240,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { if (overflow) { throw new IllegalStateException( "Dictionary encoding should not be used because of dictionary overflow.") @@ -260,9 +263,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { to } - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2 + override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2 } class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -284,7 +287,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { columnType.setField(row, ordinal, dictionary(buffer.getShort())) } - override def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining } } @@ -293,15 +296,16 @@ private[sql] case object BooleanBitSet extends CompressionScheme { val BITS_PER_LONG = 64 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new this.Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 @@ -310,7 +314,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { _uncompressedSize += BOOLEAN.defaultSize } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(BooleanBitSet.typeId) // Total element count (1 byte per Boolean value) .putInt(from.remaining) @@ -347,9 +351,9 @@ private[sql] case object BooleanBitSet extends CompressionScheme { to } - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = { + override def compressedSize: Int = { val extra = if (_uncompressedSize % BITS_PER_LONG == 0) 0 else 1 (_uncompressedSize / BITS_PER_LONG + extra) * 8 + 4 } @@ -380,22 +384,23 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private[sql] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == INT + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == INT class Encoder extends compression.Encoder[IntegerType.type] { protected var _compressedSize: Int = 0 protected var _uncompressedSize: Int = 0 - override def compressedSize = _compressedSize - override def uncompressedSize = _uncompressedSize + override def compressedSize: Int = _compressedSize + override def uncompressedSize: Int = _uncompressedSize private var prevValue: Int = _ @@ -459,22 +464,23 @@ private[sql] case object IntDelta extends CompressionScheme { private[sql] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == LONG + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == LONG class Encoder extends compression.Encoder[LongType.type] { protected var _compressedSize: Int = 0 protected var _uncompressedSize: Int = 0 - override def compressedSize = _compressedSize - override def uncompressedSize = _uncompressedSize + override def compressedSize: Int = _compressedSize + override def uncompressedSize: Int = _uncompressedSize private var prevValue: Long = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index ad44a01d0e164..18b1ba4c5c4b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -21,6 +21,7 @@ import java.util.HashMap import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -45,7 +46,7 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil } else { @@ -55,8 +56,9 @@ case class Aggregate( ClusteredDistribution(groupingExpressions) :: Nil } } + } - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) /** * An aggregate that needs to be computed for each row in a group. @@ -119,7 +121,7 @@ case class Aggregate( } } - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 7c0b72aab448e..437408d30bfd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} -import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -34,9 +35,9 @@ import org.apache.spark.util.MutablePair @DeveloperApi case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { - override def outputPartitioning = newPartitioning + override def outputPartitioning: Partitioning = newPartitioning - override def output = child.output + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] @@ -44,7 +45,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - override def execute() = attachTree(this , "execute") { + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. @@ -123,13 +124,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una */ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - def numPartitions = sqlContext.conf.numShufflePartitions + def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => // Check if every child's outputPartitioning satisfies the corresponding // required data distribution. - def meetsRequirements = + def meetsRequirements: Boolean = !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) @@ -147,7 +148,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl // datasets are both clustered by "a", but these two outputPartitionings are not // compatible. // TODO: ASSUMES TRANSITIVITY? - def compatible = + def compatible: Boolean = !operator.children .map(_.outputPartitioning) .sliding(2) @@ -158,7 +159,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl // Check if the partitioning we want to ensure is the same as the child's output // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan) = + def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child if (meetsRequirements && compatible) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 248dc1512b4d3..d8955725e59b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType +import scala.collection.immutable + /** * :: DeveloperApi :: */ @@ -58,17 +60,17 @@ object RDDConversions { case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { - override def children = Nil + override def children: Seq[LogicalPlan] = Nil - override def newInstance() = + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan) = plan match { + override def sameResult(plan: LogicalPlan): Boolean = plan match { case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id case _ => false } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) @@ -77,24 +79,24 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont /** Physical plan node for scanning data from an RDD. */ case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd + override def execute(): RDD[Row] = rdd } /** Logical plan node for scanning data from a local collection. */ case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { - override def children = Nil + override def children: Seq[LogicalPlan] = Nil - override def newInstance() = + override def newInstance(): this.type = LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan) = plan match { + override def sameResult(plan: LogicalPlan): Boolean = plan match { case LogicalRDD(_, otherRDD) => rows == rows case _ => false } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( // TODO: Improve the statistics estimation. // This is made small enough so it can be broadcasted. sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 95172420608f9..575849481faad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -42,7 +43,7 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // TODO Move out projection objects creation and transfer to // workers via closure. However we can't assume the Projection @@ -55,7 +56,7 @@ case class Expand( private[this] var idx = -1 // -1 means the initial state private[this] var input: Row = _ - override final def hasNext = (-1 < idx && idx < groups.length) || iter.hasNext + override final def hasNext: Boolean = (-1 < idx && idx < groups.length) || iter.hasNext override final def next(): Row = { if (idx <= 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 38877c28de3a8..12271048bb39c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ /** @@ -54,7 +55,7 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) - override def execute() = { + override def execute(): RDD[Row] = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 4abe26fe4afc6..89682d25ca7dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -49,7 +50,7 @@ case class GeneratedAggregate( child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (partial) { UnspecifiedDistribution :: Nil } else { @@ -60,9 +61,9 @@ case class GeneratedAggregate( } } - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - override def execute() = { + override def execute(): RDD[Row] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => a.collect { case agg: AggregateExpression => agg} } @@ -271,9 +272,9 @@ case class GeneratedAggregate( private[this] val resultIterator = buffers.entrySet.iterator() private[this] val resultProjection = resultProjectionBuilder() - def hasNext = resultIterator.hasNext + def hasNext: Boolean = resultIterator.hasNext - def next() = { + def next(): Row = { val currentGroup = resultIterator.next() resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index d3a18b37d52b9..5bd699a2fa949 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Attribute @@ -29,11 +30,11 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo private lazy val rdd = sqlContext.sparkContext.parallelize(rows) - override def execute() = rdd + override def execute(): RDD[Row] = rdd - override def executeCollect() = + override def executeCollect(): Array[Row] = rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray - override def executeTake(limit: Int) = + override def executeTake(limit: Int): Array[Row] = rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 052766c20abc2..d239637cd4b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -67,6 +67,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! + /** Specifies any partition requirements on the input data for this operator. */ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 30564e14fa896..c4534fd5f67e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -74,7 +74,7 @@ private[execution] class KryoResourcePool(size: Int) new KryoSerializer(sparkConf) } - def newInstance() = ser.newInstance() + def newInstance(): SerializerInstance = ser.newInstance() } private[sql] object SparkSqlSerializer { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5281c7502556a..2b581152e5f77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -154,7 +154,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && @@ -162,7 +162,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => true } - def allAggregates(exprs: Seq[Expression]) = + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = exprs.flatMap(_.collect { case a: AggregateExpression => a }) } @@ -257,7 +257,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { - def numPartitions = self.numPartitions + def numPartitions: Int = self.numPartitions def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 710268584cff1..20c9bc3e75542 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, SinglePartition, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.ExternalSorter @@ -33,11 +33,11 @@ import org.apache.spark.util.collection.ExternalSorter */ @DeveloperApi case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - override def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - override def execute() = child.execute().mapPartitions { iter => + override def execute(): RDD[Row] = child.execute().mapPartitions { iter => val resuableProjection = buildProjection() iter.map(resuableProjection) } @@ -48,11 +48,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends */ @DeveloperApi case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output - @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + @transient lazy val conditionEvaluator: (Row) => Boolean = newPredicate(condition, child.output) - override def execute() = child.execute().mapPartitions { iter => + override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } } @@ -64,10 +64,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output // TODO: How to pick seed? - override def execute() = child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + override def execute(): RDD[Row] = { + child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + } } /** @@ -76,8 +78,8 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: @DeveloperApi case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes - override def output = children.head.output - override def execute() = sparkContext.union(children.map(_.execute())) + override def output: Seq[Attribute] = children.head.output + override def execute(): RDD[Row] = sparkContext.union(children.map(_.execute())) } /** @@ -97,12 +99,12 @@ case class Limit(limit: Int, child: SparkPlan) /** We must copy rows when sort based shuffle is on */ private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - override def output = child.output - override def outputPartitioning = SinglePartition + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = SinglePartition override def executeCollect(): Array[Row] = child.executeTake(limit) - override def execute() = { + override def execute(): RDD[Row] = { val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.take(limit).map(row => (false, row.copy())) @@ -129,20 +131,21 @@ case class Limit(limit: Int, child: SparkPlan) @DeveloperApi case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { - override def output = child.output - override def outputPartitioning = SinglePartition + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = SinglePartition - val ord = new RowOrdering(sortOrder, child.output) + private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) - private def collectData() = child.execute().map(_.copy()).takeOrdered(limit)(ord) + private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) // TODO: Is this copying for no reason? - override def executeCollect() = + override def executeCollect(): Array[Row] = collectData().map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sparkContext.makeRDD(collectData(), 1) + override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) } /** @@ -157,17 +160,17 @@ case class Sort( global: Boolean, child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - override def execute() = attachTree(this, "sort") { + override def execute(): RDD[Row] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) iterator.map(_.copy()).toArray.sorted(ordering).iterator }, preservesPartitioning = true) } - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -182,10 +185,11 @@ case class ExternalSort( global: Boolean, child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - override def execute() = attachTree(this, "sort") { + override def execute(): RDD[Row] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) @@ -194,7 +198,7 @@ case class ExternalSort( }, preservesPartitioning = true) } - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -206,12 +210,12 @@ case class ExternalSort( */ @DeveloperApi case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil - override def execute() = { + override def execute(): RDD[Row] = { child.execute().mapPartitions { iter => val hashSet = new scala.collection.mutable.HashSet[Row]() @@ -236,9 +240,9 @@ case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { */ @DeveloperApi case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } } @@ -250,9 +254,9 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { */ @DeveloperApi case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = children.head.output + override def output: Seq[Attribute] = children.head.output - override def execute() = { + override def execute(): RDD[Row] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } @@ -265,6 +269,7 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { */ @DeveloperApi case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children = child :: Nil - def execute() = child.execute() + def children: Seq[SparkPlan] = child :: Nil + + def execute(): RDD[Row] = child.execute() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index a11232142d0fb..fad7a281dc1e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import scala.collection.mutable.ArrayBuffer /** * A logical command that is executed for its side-effects. `RunnableCommand`s are @@ -54,9 +53,9 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { */ protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) - override def output = cmd.output + override def output: Seq[Attribute] = cmd.output - override def children = Nil + override def children: Seq[SparkPlan] = Nil override def executeCollect(): Array[Row] = sideEffectResult.toArray @@ -71,9 +70,10 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { @DeveloperApi case class SetCommand( kv: Option[(String, Option[String])], - override val output: Seq[Attribute]) extends RunnableCommand with Logging { + override val output: Seq[Attribute]) + extends RunnableCommand with Logging { - override def run(sqlContext: SQLContext) = kv match { + override def run(sqlContext: SQLContext): Seq[Row] = kv match { // Configures the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => logWarning( @@ -119,10 +119,11 @@ case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = Seq(AttributeReference("plan", StringType, nullable = false)()), - extended: Boolean = false) extends RunnableCommand { + extended: Boolean = false) + extends RunnableCommand { // Run through the optimizer to generate the physical plan. - override def run(sqlContext: SQLContext) = try { + override def run(sqlContext: SQLContext): Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = sqlContext.executePlan(logicalPlan) val outputString = if (extended) queryExecution.toString else queryExecution.simpleString @@ -140,9 +141,10 @@ case class ExplainCommand( case class CacheTableCommand( tableName: String, plan: Option[LogicalPlan], - isLazy: Boolean) extends RunnableCommand { + isLazy: Boolean) + extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { plan.foreach { logicalPlan => sqlContext.registerDataFrameAsTable(DataFrame(sqlContext, logicalPlan), tableName) } @@ -166,7 +168,7 @@ case class CacheTableCommand( @DeveloperApi case class UncacheTableCommand(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.table(tableName).unpersist(blocking = false) Seq.empty[Row] } @@ -181,7 +183,7 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { @DeveloperApi case object ClearCacheCommand extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.clearCache() Seq.empty[Row] } @@ -196,9 +198,10 @@ case object ClearCacheCommand extends RunnableCommand { case class DescribeCommand( child: SparkPlan, override val output: Seq[Attribute], - isExtended: Boolean) extends RunnableCommand { + isExtended: Boolean) + extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { child.schema.fields.map { field => val cmtKey = "comment" val comment = if (field.metadata.contains(cmtKey)) field.metadata.getString(cmtKey) else "" @@ -220,7 +223,7 @@ case class DescribeCommand( case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { // The result of SHOW TABLES has two columns, tableName and isTemporary. - override val output = { + override val output: Seq[Attribute] = { val schema = StructType( StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) @@ -228,7 +231,7 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma schema.toAttributes } - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Since we need to return a Seq of rows, we will call getTables directly // instead of calling tables in sqlContext. val rows = sqlContext.catalog.getTables(databaseName).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index ffe388cfa9532..e916e68e58b5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute + import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} +import org.apache.spark.{AccumulatorParam, Accumulator} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.SparkContext._ import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -43,7 +45,7 @@ package object debug { * Augments [[SQLContext]] with debug methods. */ implicit class DebugSQLContext(sqlContext: SQLContext) { - def debug() = { + def debug(): Unit = { sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") } } @@ -88,7 +90,7 @@ package object debug { } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { - def output = child.output + def output: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { def zero(initialValue: HashSet[String]): HashSet[String] = { @@ -109,10 +111,10 @@ package object debug { */ case class ColumnMetrics( elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) - val tupleCount = sparkContext.accumulator[Int](0) + val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0) - val numColumns = child.output.size - val columnStats = Array.fill(child.output.size)(new ColumnMetrics()) + val numColumns: Int = child.output.size + val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { println(s"== ${child.simpleString} ==") @@ -123,11 +125,11 @@ package object debug { } } - def execute() = { + def execute(): RDD[Row] = { child.execute().mapPartitions { iter => new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { + def hasNext: Boolean = iter.hasNext + def next(): Row = { val currentRow = iter.next() tupleCount += 1 var i = 0 @@ -180,18 +182,18 @@ package object debug { private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { import TypeCheck._ - override def nodeName = "" + override def nodeName: String = "" /* Only required when defining this class in a REPL. override def makeCopy(args: Array[Object]): this.type = TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] */ - def output = child.output + def output: Seq[Attribute] = child.output - def children = child :: Nil + def children: List[SparkPlan] = child :: Nil - def execute() = { + def execute(): RDD[Row] = { child.execute().map { row => try typeCheck(row, child.schema) catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 2dd22c020ef12..926f5e6c137ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.rdd.RDD + import scala.concurrent._ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -42,7 +44,7 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - val timeout = { + val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { Duration.Inf @@ -53,7 +55,7 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient @@ -64,7 +66,7 @@ case class BroadcastHashJoin( sparkContext.broadcast(hashed) } - override def execute() = { + override def execute(): RDD[Row] = { val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 2ab064fd0151e..3ef1e0d7fbdd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Expression, Row} -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -34,11 +34,11 @@ case class BroadcastLeftSemiJoinHash( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - override val buildSide = BuildRight + override val buildSide: BuildSide = BuildRight - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 36aad13778bd2..83b1a83765153 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} @@ -44,7 +45,7 @@ case class BroadcastNestedLoopJoin( override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -63,7 +64,7 @@ case class BroadcastNestedLoopJoin( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - override def execute() = { + override def execute(): RDD[Row] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 76c14c02aab34..1cbc98354d673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -26,9 +28,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} */ @DeveloperApi case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = left.output ++ right.output + override def output: Seq[Attribute] = left.output ++ right.output - override def execute() = { + override def execute(): RDD[Row] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 4012d757d5f9a..851de1685509a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -41,7 +41,7 @@ trait HashJoin { case BuildRight => (rightKeys, leftKeys) } - override def output = left.output ++ right.output + override def output: Seq[Attribute] = left.output ++ right.output @transient protected lazy val buildSideKeyGenerator: Projection = newProjection(buildKeys, buildPlan.output) @@ -65,7 +65,7 @@ trait HashJoin { (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) - override final def next() = { + override final def next(): Row = { val ret = buildSide match { case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 59ef904272545..a396c0f5d56ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} +import org.apache.spark.rdd.RDD + import scala.collection.JavaConversions._ import org.apache.spark.annotation.DeveloperApi @@ -49,10 +51,10 @@ case class HashOuterJoin( case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -78,12 +80,12 @@ case class HashOuterJoin( private[this] def leftOuterIterator( key: Row, joinedRow: JoinedRow, rightIter: Iterable[Row]): Iterator[Row] = { - val ret: Iterable[Row] = ( + val ret: Iterable[Row] = { if (!key.anyNull) { val temp = rightIter.collect { - case r if (boundCondition(joinedRow.withRight(r))) => joinedRow.copy + case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.size == 0) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -91,19 +93,19 @@ case class HashOuterJoin( } else { joinedRow.withRight(rightNullRow).copy :: Nil } - ) + } ret.iterator } private[this] def rightOuterIterator( key: Row, leftIter: Iterable[Row], joinedRow: JoinedRow): Iterator[Row] = { - val ret: Iterable[Row] = ( + val ret: Iterable[Row] = { if (!key.anyNull) { val temp = leftIter.collect { - case l if (boundCondition(joinedRow.withLeft(l))) => joinedRow.copy + case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy } - if (temp.size == 0) { + if (temp.size == 0) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -111,7 +113,7 @@ case class HashOuterJoin( } else { joinedRow.withLeft(leftNullRow).copy :: Nil } - ) + } ret.iterator } @@ -130,12 +132,12 @@ case class HashOuterJoin( // 1. For those matched (satisfy the join condition) records with both sides filled, // append them directly - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + case (r, idx) if boundCondition(joinedRow.withRight(r)) => matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) - joinedRow.copy - } + joinedRow.copy() + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. @@ -143,22 +145,21 @@ case class HashOuterJoin( // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy + joinedRow.withRight(rightNullRow).copy() }) } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } + case (r, idx) if !rightMatchedSet.contains(idx) => + joinedRow(leftNullRow, r).copy() } } else { leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy + joinedRow(l, rightNullRow).copy() } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy + joinedRow(leftNullRow, r).copy() } } } @@ -182,13 +183,13 @@ case class HashOuterJoin( hashTable } - override def execute() = { + override def execute(): RDD[Row] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) joinType match { - case LeftOuter => { + case LeftOuter => val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) val keyGenerator = newProjection(leftKeys, left.output) leftIter.flatMap( currentRow => { @@ -196,8 +197,8 @@ case class HashOuterJoin( joinedRow.withLeft(currentRow) leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) }) - } - case RightOuter => { + + case RightOuter => val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val keyGenerator = newProjection(rightKeys, right.output) rightIter.flatMap ( currentRow => { @@ -205,8 +206,8 @@ case class HashOuterJoin( joinedRow.withRight(currentRow) rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) }) - } - case FullOuter => { + + case FullOuter => val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => @@ -214,7 +215,7 @@ case class HashOuterJoin( leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) } - } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 38b8993b03f82..2fa1cf5add3b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -38,7 +38,7 @@ private[joins] sealed trait HashedRelation { private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) extends HashedRelation with Serializable { - override def get(key: Row) = hashTable.get(key) + override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) } @@ -49,7 +49,7 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) extends HashedRelation with Serializable { - override def get(key: Row) = { + override def get(key: Row): CompactBuffer[Row] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 60003d1900d85..1fa7e7bd0406c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -35,12 +36,13 @@ case class LeftSemiJoinBNL( override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output = left.output + override def output: Seq[Attribute] = left.output /** The Streamed Relation */ - override def left = streamed + override def left: SparkPlan = streamed + /** The Broadcast relation */ - override def right = broadcast + override def right: SparkPlan = broadcast @transient private lazy val boundCondition = InterpretedPredicate( @@ -48,7 +50,7 @@ case class LeftSemiJoinBNL( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - override def execute() = { + override def execute(): RDD[Row] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index ea7babf3be948..a04f2a63b5a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -34,14 +35,14 @@ case class LeftSemiJoinHash( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - override val buildSide = BuildRight + override val buildSide: BuildSide = BuildRight - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 418c1c23e5546..a6cd8337c1c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -38,10 +40,10 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = left.outputPartitioning - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def execute() = { + override def execute(): RDD[Row] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 33632b8e82ff9..5b308d88d4cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} +import org.apache.spark.rdd.RDD + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -48,11 +50,13 @@ private[spark] case class PythonUDF( dataType: DataType, children: Seq[Expression]) extends Expression with SparkLogging { - override def toString = s"PythonUDF#$name(${children.mkString(",")})" + override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" def nullable: Boolean = true - override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") + override def eval(input: Row): PythonUDF.this.EvaluatedType = { + sys.error("PythonUDFs can not be directly evaluated.") + } } /** @@ -63,7 +67,7 @@ private[spark] case class PythonUDF( * multiple child operators. */ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan) = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. case p: EvaluatePython => p @@ -107,7 +111,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { } object EvaluatePython { - def apply(udf: PythonUDF, child: LogicalPlan) = + def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) /** @@ -205,10 +209,10 @@ case class EvaluatePython( resultAttribute: AttributeReference) extends logical.UnaryNode { - def output = child.output :+ resultAttribute + def output: Seq[Attribute] = child.output :+ resultAttribute // References should not include the produced attribute. - override def references = udf.references + override def references: AttributeSet = udf.references } /** @@ -219,9 +223,10 @@ case class EvaluatePython( @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children = child :: Nil - def execute() = { + def children: Seq[SparkPlan] = child :: Nil + + def execute(): RDD[Row] = { // TODO: Clean up after ourselves? val childResults = child.execute().map(_.copy()).cache() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 87304ce2496b4..3266b972128ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -306,7 +306,8 @@ private[sql] class JDBCRDD( /** * Runs the SQL query against the JDBC driver. */ - override def compute(thePart: Partition, context: TaskContext) = new Iterator[Row] { + override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = new Iterator[Row] + { var closed = false var finished = false var gotNext = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 1778d39c42e2b..df687e6da9bea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.jdbc +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.types.StructType + import scala.collection.mutable.ArrayBuffer import java.sql.DriverManager @@ -122,9 +126,9 @@ private[sql] case class JDBCRelation( extends BaseRelation with PrunedFilteredScan { - override val schema = JDBCRDD.resolveTable(url, table) + override val schema: StructType = JDBCRDD.resolveTable(url, table) - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName JDBCRDD.scanTable( sqlContext.sparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index b645199ded18c..b1e363d02edfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.json import java.io.IOException import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ @@ -104,10 +106,10 @@ private[sql] case class JSONRelation( samplingRatio, sqlContext.conf.columnNameOfCorruptRecord))) - override def buildScan() = + override def buildScan(): RDD[Row] = JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord) - override def insert(data: DataFrame, overwrite: Boolean) = { + override def insert(data: DataFrame, overwrite: Boolean): Unit = { val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 7d62f3728f036..f898e4b37a56b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -488,7 +488,7 @@ private[parquet] object CatalystTimestampConverter { // Also we use NanoTime and Int96Values from parquet-examples. // We utilize jodd to convert between NanoTime and Timestamp val parquetTsCalendar = new ThreadLocal[Calendar] - def getCalendar = { + def getCalendar: Calendar = { // this is a cache for the calendar instance. if (parquetTsCalendar.get == null) { parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index fd161bae128ad..fcb9513ab66f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -71,16 +71,22 @@ private[sql] case class ParquetRelation( sqlContext.conf.isParquetINT96AsTimestamp) lazy val attributeMap = AttributeMap(output.map(o => o -> o)) - override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + override def newInstance(): this.type = { + ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + } // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case p: ParquetRelation => p.path == path && p.output == output case _ => false } + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(path, output) + } + // TODO: Use data from the footers. override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 62813a981e685..5130d8ad5e003 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -152,8 +152,8 @@ private[sql] case class ParquetTableScan( if (primitiveRow) { new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { + def hasNext: Boolean = iter.hasNext + def next(): Row = { // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. val row = iter.next()._2.asInstanceOf[SpecificMutableRow] @@ -171,8 +171,8 @@ private[sql] case class ParquetTableScan( // Create a mutable row since we need to fill in values from partition columns. val mutableRow = new GenericMutableRow(outputSize) new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { + def hasNext: Boolean = iter.hasNext + def next(): Row = { // We are using CatalystGroupConverter and it returns a GenericRow. // Since GenericRow is not mutable, we just cast it to a Row. val row = iter.next()._2.asInstanceOf[Row] @@ -255,7 +255,7 @@ private[sql] case class InsertIntoParquetTable( /** * Inserts all rows into the Parquet file. */ - override def execute() = { + override def execute(): RDD[Row] = { // TODO: currently we do not check whether the "schema"s are compatible // That means if one first creates a table and then INSERTs data with // and incompatible schema the execution will fail. It would be nice @@ -302,7 +302,7 @@ private[sql] case class InsertIntoParquetTable( childRdd } - override def output = child.output + override def output: Seq[Attribute] = child.output /** * Stores the given Row RDD as a Hadoop file. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index c38b6e8c61d8a..10b8876c1d31c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -181,7 +181,7 @@ private[sql] case class ParquetRelation2( private val defaultPartitionName = parameters.getOrElse( ParquetRelation2.DEFAULT_PARTITION_NAME, "__HIVE_DEFAULT_PARTITION__") - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case relation: ParquetRelation2 => // If schema merging is required, we don't compare the actual schemas since they may evolve. val schemaEquality = if (shouldMergeSchemas) { @@ -198,6 +198,23 @@ private[sql] case class ParquetRelation2( case _ => false } + override def hashCode(): Int = { + if (shouldMergeSchemas) { + com.google.common.base.Objects.hashCode( + shouldMergeSchemas: java.lang.Boolean, + paths.toSet, + maybeMetastoreSchema, + maybePartitionSpec) + } else { + com.google.common.base.Objects.hashCode( + shouldMergeSchemas: java.lang.Boolean, + schema, + paths.toSet, + maybeMetastoreSchema, + maybePartitionSpec) + } + } + private[sql] def sparkContext = sqlContext.sparkContext private class MetadataCache { @@ -370,19 +387,19 @@ private[sql] case class ParquetRelation2( @transient private val metadataCache = new MetadataCache metadataCache.refresh() - def partitionSpec = metadataCache.partitionSpec + def partitionSpec: PartitionSpec = metadataCache.partitionSpec - def partitionColumns = metadataCache.partitionSpec.partitionColumns + def partitionColumns: StructType = metadataCache.partitionSpec.partitionColumns - def partitions = metadataCache.partitionSpec.partitions + def partitions: Seq[Partition] = metadataCache.partitionSpec.partitions - def isPartitioned = partitionColumns.nonEmpty + def isPartitioned: Boolean = partitionColumns.nonEmpty private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema private def parquetSchema = metadataCache.parquetSchema - override def schema = metadataCache.schema + override def schema: StructType = metadataCache.schema private def isSummaryFile(file: Path): Boolean = { file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || @@ -425,8 +442,10 @@ private[sql] case class ParquetRelation2( .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) if (isPartitioned) { - def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - logInfo(s"Reading $percentRead% of partitions") + logInfo { + val percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 + s"Reading $percentRead% of partitions" + } } val requiredColumns = output.map(_.name) @@ -703,7 +722,7 @@ private[sql] object ParquetRelation2 { private[parquet] def mergeMetastoreParquetSchema( metastoreSchema: StructType, parquetSchema: StructType): StructType = { - def schemaConflictMessage = + def schemaConflictMessage: String = s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: |${metastoreSchema.prettyJson} | diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala index e24475292ceaf..70bcca7526aae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -26,7 +26,7 @@ private[parquet] class NanoTime extends Serializable { private var julianDay = 0 private var timeOfDayNanos = 0L - def set(julianDay: Int, timeOfDayNanos: Long) = { + def set(julianDay: Int, timeOfDayNanos: Long): this.type = { this.julianDay = julianDay this.timeOfDayNanos = timeOfDayNanos this @@ -45,11 +45,11 @@ private[parquet] class NanoTime extends Serializable { Binary.fromByteBuffer(buf) } - def writeValue(recordConsumer: RecordConsumer) { + def writeValue(recordConsumer: RecordConsumer): Unit = { recordConsumer.addBinary(toBinary) } - override def toString = + override def toString: String = "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index 12b59ba20bb10..f374abffdd505 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -30,24 +30,28 @@ private[sql] case class LogicalRelation(relation: BaseRelation) override val output: Seq[AttributeReference] = relation.schema.toAttributes // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output case _ => false } - override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(relation, output) + } + + override def sameResult(otherPlan: LogicalPlan): Boolean = otherPlan match { case LogicalRelation(otherRelation) => relation == otherRelation case _ => false } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = BigInt(relation.sizeInBytes) ) /** Used to lookup original attribute capitalization */ - val attributeMap = AttributeMap(output.map(o => (o, o))) + val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o))) - def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + def newInstance(): this.type = LogicalRelation(relation).asInstanceOf[this.type] - override def simpleString = s"Relation[${output.mkString(",")}] $relation" + override def simpleString: String = s"Relation[${output.mkString(",")}] $relation" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 0e540dad81283..9bbe06e59ba30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -27,7 +27,7 @@ private[sql] case class InsertIntoDataSource( overwrite: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 76754a6ce4617..d57406645eefa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -362,7 +362,7 @@ private[sql] case class CreateTableUsingAsSelect( mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends UnaryNode { - override def output = Seq.empty[Attribute] + override def output: Seq[Attribute] = Seq.empty[Attribute] // TODO: Override resolved after we support databaseName. // override lazy val resolved = databaseName != None && childrenResolved } @@ -373,7 +373,7 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext) = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) @@ -388,7 +388,7 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - def run(sqlContext: SQLContext) = { + def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df) sqlContext.registerDataFrameAsTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index cfa58f1442218..5a78001117d1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -53,7 +53,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { def castAndRenameChildOutput( insertInto: InsertIntoTable, expectedOutput: Seq[Attribute], - child: LogicalPlan) = { + child: LogicalPlan): InsertIntoTable = { val newChildOutput = expectedOutput.zip(child.output).map { case (expected, actual) => val needCast = !expected.dataType.sameType(actual.dataType) @@ -79,7 +79,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { * A rule to do various checks before inserting into or writing to a data source table. */ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) { - def failAnalysis(msg: String) = { throw new AnalysisException(msg) } + def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) } def apply(plan: LogicalPlan): Unit = { plan.foreach { From 25e271d9fbb3394931d23822a1b2020e9d9b46b3 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 20 Mar 2015 17:14:09 -0700 Subject: [PATCH 23/47] [SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract learning curve Added evaluateEachIteration to allow the user to manually extract the error for each iteration of GradientBoosting. The internal optimisation can be dealt with later. Author: MechCoder Closes #4906 from MechCoder/spark-6025 and squashes the following commits: 67146ab [MechCoder] Minor 352001f [MechCoder] Minor 6e8aa10 [MechCoder] Made the following changes Used mapPartition instead of map Refactored computeError and unpersisted broadcast variables bc99ac6 [MechCoder] Refactor the method and stuff dbda033 [MechCoder] [SPARK-6025] Add helper method evaluateEachIteration to extract learning curve --- docs/mllib-ensembles.md | 4 +- .../spark/mllib/tree/loss/AbsoluteError.scala | 17 ++---- .../spark/mllib/tree/loss/LogLoss.scala | 20 ++----- .../apache/spark/mllib/tree/loss/Loss.scala | 14 ++++- .../spark/mllib/tree/loss/SquaredError.scala | 17 ++---- .../mllib/tree/model/treeEnsembleModels.scala | 54 +++++++++++++++++++ .../tree/GradientBoostedTreesSuite.scala | 16 +++++- 7 files changed, 96 insertions(+), 46 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index cbfb682609af3..7521fb14a7bd6 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -464,8 +464,8 @@ first one being the training dataset and the second being the validation dataset The training is stopped when the improvement in the validation error is not more than a certain tolerance (supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error decreases initially and later increases. There might be cases in which the validation error does not change monotonically, -and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of -iterations. +and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration` +(which gives the error or loss per iteration) to tune the number of iterations. ### Examples diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index d1bde15e6b150..793dd664c5d5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -47,18 +47,9 @@ object AbsoluteError extends Loss { if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean absolute error of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = model.predict(y.features) - y.label - math.abs(err) - }.mean() + override def computeError(prediction: Double, label: Double): Double = { + val err = label - prediction + math.abs(err) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 55213e695638c..51b1aed167b66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -50,20 +50,10 @@ object LogLoss extends Loss { - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean log loss of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { case point => - val prediction = model.predict(point.features) - val margin = 2.0 * point.label * prediction - // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. - 2.0 * MLUtils.log1pExp(-margin) - }.mean() + override def computeError(prediction: Double, label: Double): Double = { + val margin = 2.0 * label * prediction + // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. + 2.0 * MLUtils.log1pExp(-margin) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index e1169d9f66ea4..357869ff6b333 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -47,6 +47,18 @@ trait Loss extends Serializable { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data */ - def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double + def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map(point => computeError(model.predict(point.features), point.label)).mean() + } + + /** + * Method to calculate loss when the predictions are already known. + * Note: This method is used in the method evaluateEachIteration to avoid recomputing the + * predicted values from previously fit trees. + * @param prediction Predicted label. + * @param label True label. + * @return Measure of model error on datapoint. + */ + def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 50ecaa2f86f35..b990707ca4525 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -47,18 +47,9 @@ object SquaredError extends Loss { 2.0 * (model.predict(point.features) - point.label) } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean squared error of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = model.predict(y.features) - y.label - err * err - }.mean() + override def computeError(prediction: Double, label: Double): Double = { + val err = prediction - label + err * err } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index f160852c69c77..1950254b2aa6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.mllib.tree.loss.Loss import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext @@ -108,6 +110,58 @@ class GradientBoostedTreesModel( } override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + + /** + * Method to compute error or loss for every iteration of gradient boosting. + * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param loss evaluation metric. + * @return an array with index i having the losses or errors for the ensemble + * containing the first i+1 trees + */ + def evaluateEachIteration( + data: RDD[LabeledPoint], + loss: Loss): Array[Double] = { + + val sc = data.sparkContext + val remappedData = algo match { + case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + case _ => data + } + + val numIterations = trees.length + val evaluationArray = Array.fill(numIterations)(0.0) + + var predictionAndError: RDD[(Double, Double)] = remappedData.map { i => + val pred = treeWeights(0) * trees(0).predict(i.features) + val error = loss.computeError(pred, i.label) + (pred, error) + } + evaluationArray(0) = predictionAndError.values.mean() + + // Avoid the model being copied across numIterations. + val broadcastTrees = sc.broadcast(trees) + val broadcastWeights = sc.broadcast(treeWeights) + + (1 until numIterations).map { nTree => + predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => + val currentTree = broadcastTrees.value(nTree) + val currentTreeWeight = broadcastWeights.value(nTree) + iter.map { + case (point, (pred, error)) => { + val newPred = pred + currentTree.predict(point.features) * currentTreeWeight + val newError = loss.computeError(newPred, point.label) + (newPred, newError) + } + } + } + evaluationArray(nTree) = predictionAndError.values.mean() + } + + broadcastTrees.unpersist() + broadcastWeights.unpersist() + evaluationArray + } + } object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index b437aeaaf0547..55b0bac7d49fe 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) val gbtValidate = new GradientBoostedTrees(boostingStrategy) .runWithValidation(trainRdd, validateRdd) - assert(gbtValidate.numTrees !== numIterations) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) // Test that it performs better on the validation dataset. - val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) + val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) val (errorWithoutValidation, errorWithValidation) = { if (algo == Classification) { val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) @@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } } } } From bc37c9743e065a0c756363c7b70e88f22a6e6edd Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 21 Mar 2015 10:53:04 +0800 Subject: [PATCH 24/47] [SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful Do the same check as #4610 for ParquetRelation2. Author: Yanbo Liang Closes #5107 from yanboliang/spark-5821-parquet and squashes the following commits: 7092c8d [Yanbo Liang] ParquetRelation2 CTAS should check if delete is successful --- .../apache/spark/sql/parquet/newParquet.scala | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 10b8876c1d31c..fbe7a419feb52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -611,13 +611,22 @@ private[sql] case class ParquetRelation2( val destinationPath = new Path(paths.head) if (overwrite) { - try { - destinationPath.getFileSystem(conf).delete(destinationPath, true) - } catch { - case e: IOException => + val fs = destinationPath.getFileSystem(conf) + if (fs.exists(destinationPath)) { + var success: Boolean = false + try { + success = fs.delete(destinationPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${destinationPath.toString} prior" + + s" to writing to Parquet table:\n${e.toString}") + } + if (!success) { throw new IOException( s"Unable to clear output directory ${destinationPath.toString} prior" + - s" to writing to Parquet file:\n${e.toString}") + s" to writing to Parquet table.") + } } } From 937c1e5503963e67a5412be993d30dbec6fc9883 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 21 Mar 2015 11:18:45 +0800 Subject: [PATCH 25/47] [SPARK-6315] [SQL] Also tries the case class string parser while reading Parquet schema When writing Parquet files, Spark 1.1.x persists the schema string into Parquet metadata with the result of `StructType.toString`, which was then deprecated in Spark 1.2 by a schema string in JSON format. But we still need to take the old schema format into account while reading Parquet files. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5034) Author: Cheng Lian Closes #5034 from liancheng/spark-6315 and squashes the following commits: a182f58 [Cheng Lian] Adds a regression test b9c6dbe [Cheng Lian] Also tries the case class string parser while reading Parquet schema --- .../apache/spark/sql/parquet/newParquet.scala | 23 +++++++++- .../spark/sql/parquet/ParquetIOSuite.scala | 42 +++++++++++++++++-- 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index fbe7a419feb52..410600b0529d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -681,7 +681,7 @@ private[sql] case class ParquetRelation2( } } -private[sql] object ParquetRelation2 { +private[sql] object ParquetRelation2 extends Logging { // Whether we should merge schemas collected from all Parquet part-files. val MERGE_SCHEMA = "mergeSchema" @@ -701,7 +701,26 @@ private[sql] object ParquetRelation2 { .getKeyValueMetaData .toMap .get(RowReadSupport.SPARK_METADATA_KEY) - .map(DataType.fromJson(_).asInstanceOf[StructType]) + .flatMap { serializedSchema => + // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to + // whatever is available. + Try(DataType.fromJson(serializedSchema)) + .recover { case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(serializedSchema) + } + .recover { case cause: Throwable => + logWarning( + s"""Failed to parse serialized Spark schema in Parquet key-value metadata: + |\t$serializedSchema + """.stripMargin, + cause) + } + .map(_.asInstanceOf[StructType]) + .toOption + } maybeSparkSchema.getOrElse { // Falls back to Parquet schema if Spark SQL schema is absent. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index a70b3c7ce48d3..5438095addeaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -28,8 +28,8 @@ import parquet.example.data.simple.SimpleGroup import parquet.example.data.{Group, GroupWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.{ParquetFileWriter, ParquetWriter} +import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} +import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport @@ -330,6 +330,42 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } + test("SPARK-6315 regression test") { + // Spark 1.1 and prior versions write Spark schema as case class string into Parquet metadata. + // This has been deprecated by JSON format since 1.2. Notice that, 1.3 further refactored data + // types API, and made StructType.fields an array. This makes the result of StructType.toString + // different from prior versions: there's no "Seq" wrapping the fields part in the string now. + val sparkSchema = + "StructType(Seq(StructField(a,BooleanType,false),StructField(b,IntegerType,false)))" + + // The Parquet schema is intentionally made different from the Spark schema. Because the new + // Parquet data source simply falls back to the Parquet schema once it fails to parse the Spark + // schema. By making these two different, we are able to assert the old style case class string + // is parsed successfully. + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 c; + |} + """.stripMargin) + + withTempPath { location => + val extraMetadata = Map(RowReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val path = new Path(location.getCanonicalPath) + + ParquetFileWriter.writeMetadataFile( + sparkContext.hadoopConfiguration, + path, + new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) + + assertResult(parquetFile(path.toString).schema) { + StructType( + StructField("a", BooleanType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: + Nil) + } + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { From e5d2c37c68ac00a57c2542e62d1c5b4ca267c89e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 21 Mar 2015 11:23:28 +0800 Subject: [PATCH 26/47] [SPARK-5821] [SQL] JSON CTAS command should throw error message when delete path failure When using "CREATE TEMPORARY TABLE AS SELECT" to create JSON table, we first delete the path file or directory and then generate a new directory with the same name. But if only read permission was granted, the delete failed. Here we just throwing an error message to let users know what happened. ParquetRelation2 may also hit this problem. I think to restrict JSONRelation and ParquetRelation2 must base on directory is more reasonable for access control. Maybe I can do it in follow up works. Author: Yanbo Liang Author: Yanbo Liang Closes #4610 from yanboliang/jsonInsertImprovements and squashes the following commits: c387fce [Yanbo Liang] fix typos 42d7fb6 [Yanbo Liang] add unittest & fix output format 46f0d9d [Yanbo Liang] Update JSONRelation.scala e2df8d5 [Yanbo Liang] check path exisit when write 79f7040 [Yanbo Liang] Update JSONRelation.scala e4bc229 [Yanbo Liang] Update JSONRelation.scala 5a42d83 [Yanbo Liang] JSONRelation CTAS should check if delete is successful --- .../apache/spark/sql/json/JSONRelation.scala | 36 +++++++++++++++---- .../sources/CreateTableAsSelectSuite.scala | 25 ++++++++++++- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index b1e363d02edfe..f4c99b4b56606 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -68,9 +68,23 @@ private[sql] class DefaultSource mode match { case SaveMode.Append => sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") - case SaveMode.Overwrite => - fs.delete(filesystemPath, true) + case SaveMode.Overwrite => { + var success: Boolean = false + try { + success = fs.delete(filesystemPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table:\n${e.toString}") + } + if (!success) { + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table.") + } true + } case SaveMode.ErrorIfExists => sys.error(s"path $path already exists.") case SaveMode.Ignore => false @@ -114,13 +128,21 @@ private[sql] case class JSONRelation( val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) if (overwrite) { - try { - fs.delete(filesystemPath, true) - } catch { - case e: IOException => + if (fs.exists(filesystemPath)) { + var success: Boolean = false + try { + success = fs.delete(filesystemPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table:\n${e.toString}") + } + if (!success) { throw new IOException( s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to INSERT OVERWRITE a JSON table:\n${e.toString}") + + s" to writing to JSON table.") + } } // Write the data. data.toJSON.saveAsTextFile(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 2975a7fee4c96..20a23b3bd6aa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import java.io.File +import java.io.{IOException, File} import org.apache.spark.sql.AnalysisException import org.scalatest.BeforeAndAfterAll @@ -62,6 +62,29 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { dropTempTable("jsonTable") } + test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { + val childPath = new File(path.toString, "child") + path.mkdir() + childPath.createNewFile() + path.setWritable(false) + + val e = intercept[IOException] { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + sql("SELECT a, b FROM jsonTable").collect() + } + assert(e.getMessage().contains("Unable to clear output directory")) + + path.setWritable(true) + } + test("create a table, drop it and create another one with the same name") { sql( s""" From 52dd4b2b277eb48bc89db9b21d25f5e836c1d348 Mon Sep 17 00:00:00 2001 From: x1- Date: Sat, 21 Mar 2015 13:22:34 -0700 Subject: [PATCH 27/47] [SPARK-5320][SQL]Add statistics method at NoRelation (override super). Because of no statistics override, in spute of super class say 'LeafNode must override'. fix issue [SPARK-5320: Joins on simple table created using select gives error](https://issues.apache.org/jira/browse/SPARK-5320) Author: x1- Closes #5105 from x1-/SPARK-5320 and squashes the following commits: e561aac [x1-] Add statistics method at NoRelation (override super). --- .../sql/catalyst/plans/logical/basicOperators.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 1e7b449d75b80..384fe53a68362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -289,6 +289,15 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { case object NoRelation extends LeafNode { override def output = Nil + + /** + * Computes [[Statistics]] for this plan. The default implementation assumes the output + * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * of cartesian joins. + * + * [[LeafNode]]s must override this. + */ + override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { From ee569a0c7171d149eee52877def902378eaf695e Mon Sep 17 00:00:00 2001 From: Venkata Ramana Gollamudi Date: Sat, 21 Mar 2015 13:24:24 -0700 Subject: [PATCH 28/47] [SPARK-5680][SQL] Sum function on all null values, should return zero SELECT sum('a'), avg('a'), variance('a'), std('a') FROM src; Should give output as 0.0 NULL NULL NULL This fixes hive udaf_number_format.q Author: Venkata Ramana G Author: Venkata Ramana Gollamudi Closes #4466 from gvramana/sum_fix and squashes the following commits: 42e14d1 [Venkata Ramana Gollamudi] Added comments 39415c0 [Venkata Ramana Gollamudi] Handled the partitioned Sum expression scenario df66515 [Venkata Ramana Gollamudi] code style fix 4be2606 [Venkata Ramana Gollamudi] Add udaf_number_format to whitelist and golden answer 330fd64 [Venkata Ramana Gollamudi] fix sum function for all null data --- .../sql/catalyst/expressions/aggregates.scala | 68 ++++++++++++++++++- .../execution/HiveCompatibilitySuite.scala | 1 + ..._format-0-eff4ef3c207d14d5121368f294697964 | 0 ..._format-1-4a03c4328565c60ca99689239f07fb16 | 1 + 4 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 create mode 100644 sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 735b7488fdcbd..5297d1e31246c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -346,13 +346,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case DecimalType.Fixed(_, _) => val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() SplitEvaluation( - Cast(Sum(partialSum.toAttribute), dataType), + Cast(CombineSum(partialSum.toAttribute), dataType), partialSum :: Nil) case _ => val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - Sum(partialSum.toAttribute), + CombineSum(partialSum.toAttribute), partialSum :: Nil) } } @@ -360,6 +360,30 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def newInstance() = new SumFunction(child, this) } +/** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class CombineSum(child: Expression) extends AggregateExpression { + def this() = this(null) + + override def children = child :: Nil + override def nullable = true + override def dataType = child.dataType + override def toString = s"CombineSum($child)" + override def newInstance() = new CombineSumFunction(child, this) +} + case class SumDistinct(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -565,7 +589,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val sum = MutableLiteral(null, calcType) - private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) + private val addFunction = + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) override def update(input: Row): Unit = { sum.update(addFunction, input) @@ -580,6 +605,43 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr } } +case class CombineSumFunction(expr: Expression, base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + private val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + private val zero = Cast(Literal(0), calcType) + + private val sum = MutableLiteral(null, calcType) + + private val addFunction = + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) + + override def update(input: Row): Unit = { + val result = expr.eval(input) + // partial sum result can be null only when no input rows present + if(result != null) { + sum.update(addFunction, input) + } + } + + override def eval(input: Row): Any = { + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(sum, dataType).eval(null) + case _ => sum.eval(null) + } + } +} + case class SumDistinctFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 5180a7f09d80f..2ae9d018e1b1b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -800,6 +800,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", + "udaf_number_format", "udf2", "udf5", "udf6", diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 new file mode 100644 index 0000000000000..c6f275a0db131 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 @@ -0,0 +1 @@ +0.0 NULL NULL NULL From 94a102acb80a7c77f57409ece1f8dbbba791b774 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 21 Mar 2015 13:27:53 -0700 Subject: [PATCH 29/47] [SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser. This PR creates a trait `DataTypeParser` used to parse data types. This trait aims to be single place to provide the functionality of parsing data types' string representation. It is currently mixed in with `DDLParser` and `SqlParser`. It is also used to parse the data type for `DataFrame.cast` and to convert Hive metastore's data type string back to a `DataType`. JIRA: https://issues.apache.org/jira/browse/SPARK-6250 Author: Yin Huai Closes #5078 from yhuai/ddlKeywords and squashes the following commits: 0e66097 [Yin Huai] Special handle struct<>. fea6012 [Yin Huai] Style. c9733fb [Yin Huai] Create a trait to parse data types. --- .../apache/spark/sql/catalyst/SqlParser.scala | 27 +--- .../spark/sql/types/DataTypeParser.scala | 115 +++++++++++++++++ .../spark/sql/types/DataTypeParserSuite.scala | 116 ++++++++++++++++++ .../scala/org/apache/spark/sql/Column.scala | 15 +-- .../org/apache/spark/sql/sources/ddl.scala | 80 +----------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 8 +- 6 files changed, 241 insertions(+), 120 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 54ab13ca352d2..ea7d44a3723d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.types._ * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser { +class SqlParser extends AbstractSparkSQLParser with DataTypeParser { def parseExpression(input: String): Expression = { // Initialize the Keywords. @@ -61,11 +61,8 @@ class SqlParser extends AbstractSparkSQLParser { protected val CAST = Keyword("CAST") protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") - protected val DATE = Keyword("DATE") - protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") - protected val DOUBLE = Keyword("DOUBLE") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") @@ -78,7 +75,6 @@ class SqlParser extends AbstractSparkSQLParser { protected val IF = Keyword("IF") protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") - protected val INT = Keyword("INT") protected val INSERT = Keyword("INSERT") protected val INTERSECT = Keyword("INTERSECT") protected val INTO = Keyword("INTO") @@ -105,13 +101,11 @@ class SqlParser extends AbstractSparkSQLParser { protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") protected val SQRT = Keyword("SQRT") - protected val STRING = Keyword("STRING") protected val SUBSTR = Keyword("SUBSTR") protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") - protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") @@ -315,7 +309,9 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val cast: Parser[Expression] = - CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } + CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { + case exp ~ t => Cast(exp, t) + } protected lazy val literal: Parser[Literal] = ( numericLiteral @@ -387,19 +383,4 @@ class SqlParser extends AbstractSparkSQLParser { (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString(".")) } - - protected lazy val dataType: Parser[DataType] = - ( STRING ^^^ StringType - | TIMESTAMP ^^^ TimestampType - | DOUBLE ^^^ DoubleType - | fixedDecimalType - | DECIMAL ^^^ DecimalType.Unlimited - | DATE ^^^ DateType - | INT ^^^ IntegerType - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala new file mode 100644 index 0000000000000..89278f7dbc806 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.language.implicitConversions +import scala.util.matching.Regex +import scala.util.parsing.combinator.syntactical.StandardTokenParsers + +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * This is a data type parser that can be used to parse string representations of data types + * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser. + */ +private[sql] trait DataTypeParser extends StandardTokenParsers { + + // This is used to create a parser from a regex. We are using regexes for data type strings + // since these strings can be also used as column names or field names. + import lexical.Identifier + implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", + { case Identifier(str) if regex.unapplySeq(str).isDefined => str } + ) + + protected lazy val primitiveType: Parser[DataType] = + "(?i)string".r ^^^ StringType | + "(?i)float".r ^^^ FloatType | + "(?i)int".r ^^^ IntegerType | + "(?i)tinyint".r ^^^ ByteType | + "(?i)smallint".r ^^^ ShortType | + "(?i)double".r ^^^ DoubleType | + "(?i)bigint".r ^^^ LongType | + "(?i)binary".r ^^^ BinaryType | + "(?i)boolean".r ^^^ BooleanType | + fixedDecimalType | + "(?i)decimal".r ^^^ DecimalType.Unlimited | + "(?i)date".r ^^^ DateType | + "(?i)timestamp".r ^^^ TimestampType | + varchar + + protected lazy val fixedDecimalType: Parser[DataType] = + ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val varchar: Parser[DataType] = + "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + + protected lazy val arrayType: Parser[DataType] = + "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } + + protected lazy val mapType: Parser[DataType] = + "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ident ~ ":" ~ dataType ^^ { + case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) + } + + protected lazy val structType: Parser[DataType] = + ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { + case fields => new StructType(fields.toArray) + }) | + ("(?i)struct".r ~ "<>" ^^^ StructType(Nil)) + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + def toDataType(dataTypeString: String): DataType = synchronized { + phrase(dataType)(new lexical.Scanner(dataTypeString)) match { + case Success(result, _) => result + case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString)) + } + } + + private def failMessage(dataTypeString: String): String = { + s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " + + "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " + + "Please note that backtick itself is not supported in a field name." + } +} + +private[sql] object DataTypeParser { + lazy val dataTypeParser = new DataTypeParser { + override val lexical = new SqlLexical + } + + def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) +} + +/** The exception thrown from the [[DataTypeParser]]. */ +protected[sql] class DataTypeException(message: String) extends Exception(message) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala new file mode 100644 index 0000000000000..1ba21b64603ac --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -0,0 +1,116 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +class DataTypeParserSuite extends FunSuite { + + def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { + test(s"parse ${dataTypeString.replace("\n", "")}") { + assert(DataTypeParser(dataTypeString) === expectedDataType) + } + } + + def unsupported(dataTypeString: String): Unit = { + test(s"$dataTypeString is not supported") { + intercept[DataTypeException](DataTypeParser(dataTypeString)) + } + } + + checkDataType("int", IntegerType) + checkDataType("BooLean", BooleanType) + checkDataType("tinYint", ByteType) + checkDataType("smallINT", ShortType) + checkDataType("INT", IntegerType) + checkDataType("bigint", LongType) + checkDataType("float", FloatType) + checkDataType("dOUBle", DoubleType) + checkDataType("decimal(10, 5)", DecimalType(10, 5)) + checkDataType("decimal", DecimalType.Unlimited) + checkDataType("DATE", DateType) + checkDataType("timestamp", TimestampType) + checkDataType("string", StringType) + checkDataType("varchAr(20)", StringType) + checkDataType("BINARY", BinaryType) + + checkDataType("array", ArrayType(DoubleType, true)) + checkDataType("Array>", ArrayType(MapType(IntegerType, ByteType, true), true)) + checkDataType( + "array>", + ArrayType(StructType(StructField("tinYint", ByteType, true) :: Nil), true) + ) + checkDataType("MAP", MapType(IntegerType, StringType, true)) + checkDataType("MAp>", MapType(IntegerType, ArrayType(DoubleType), true)) + checkDataType( + "MAP>", + MapType(IntegerType, StructType(StructField("varchar", StringType, true) :: Nil), true) + ) + + checkDataType( + "struct", + StructType( + StructField("intType", IntegerType, true) :: + StructField("ts", TimestampType, true) :: Nil) + ) + // It is fine to use the data type string as the column name. + checkDataType( + "Struct", + StructType( + StructField("int", IntegerType, true) :: + StructField("timestamp", TimestampType, true) :: Nil) + ) + checkDataType( + """ + |struct< + | struct:struct, + | MAP:Map, + | arrAy:Array> + """.stripMargin, + StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.Unlimited, true) :: + StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: + StructField("MAP", MapType(TimestampType, StringType), true) :: + StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + ) + // A column name can be a reserved word in our DDL parser and SqlParser. + checkDataType( + "Struct", + StructType( + StructField("TABLE", StringType, true) :: + StructField("CASE", BooleanType, true) :: Nil) + ) + // Use backticks to quote column names having special characters. + checkDataType( + "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", + StructType( + StructField("x+y", IntegerType, true) :: + StructField("!@#$%^&*()", StringType, true) :: + StructField("1_2.345<>:\"", StringType, true) :: Nil) + ) + // Empty struct. + checkDataType("strUCt<>", StructType(Nil)) + + unsupported("it is not a data type") + unsupported("struct") + unsupported("struct") + unsupported("struct<`x``y` int>") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b7a13a1b26802..ec7d15f5bc4e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -624,20 +624,7 @@ class Column(protected[sql] val expr: Expression) { * * @group expr_ops */ - def cast(to: String): Column = cast(to.toLowerCase match { - case "string" | "str" => StringType - case "boolean" => BooleanType - case "byte" => ByteType - case "short" => ShortType - case "int" => IntegerType - case "long" => LongType - case "float" => FloatType - case "double" => DoubleType - case "decimal" => DecimalType.Unlimited - case "date" => DateType - case "timestamp" => TimestampType - case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""") - }) + def cast(to: String): Column = cast(DataTypeParser(to)) /** * Returns an ordering used in sorting. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index d57406645eefa..d2e807d3a69b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils * A parser for foreign DDL commands. */ private[sql] class DDLParser( - parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with Logging { + parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { try { @@ -46,14 +47,6 @@ private[sql] class DDLParser( } } - def parseType(input: String): DataType = { - lexical.initialize(reservedWords) - phrase(dataType)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => throw new DDLException(s"Unsupported dataType: $x") - } - } - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val CREATE = Keyword("CREATE") @@ -70,24 +63,6 @@ private[sql] class DDLParser( protected val COMMENT = Keyword("COMMENT") protected val REFRESH = Keyword("REFRESH") - // Data types. - protected val STRING = Keyword("STRING") - protected val BINARY = Keyword("BINARY") - protected val BOOLEAN = Keyword("BOOLEAN") - protected val TINYINT = Keyword("TINYINT") - protected val SMALLINT = Keyword("SMALLINT") - protected val INT = Keyword("INT") - protected val BIGINT = Keyword("BIGINT") - protected val FLOAT = Keyword("FLOAT") - protected val DOUBLE = Keyword("DOUBLE") - protected val DECIMAL = Keyword("DECIMAL") - protected val DATE = Keyword("DATE") - protected val TIMESTAMP = Keyword("TIMESTAMP") - protected val VARCHAR = Keyword("VARCHAR") - protected val ARRAY = Keyword("ARRAY") - protected val MAP = Keyword("MAP") - protected val STRUCT = Keyword("STRUCT") - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable protected def start: Parser[LogicalPlan] = ddl @@ -189,58 +164,9 @@ private[sql] class DDLParser( new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() case None => Metadata.empty } - StructField(columnName, typ, nullable = true, meta) - } - - protected lazy val primitiveType: Parser[DataType] = - STRING ^^^ StringType | - BINARY ^^^ BinaryType | - BOOLEAN ^^^ BooleanType | - TINYINT ^^^ ByteType | - SMALLINT ^^^ ShortType | - INT ^^^ IntegerType | - BIGINT ^^^ LongType | - FLOAT ^^^ FloatType | - DOUBLE ^^^ DoubleType | - fixedDecimalType | // decimal with precision/scale - DECIMAL ^^^ DecimalType.Unlimited | // decimal with no precision/scale - DATE ^^^ DateType | - TIMESTAMP ^^^ TimestampType | - VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType - - protected lazy val fixedDecimalType: Parser[DataType] = - (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - ARRAY ~> "<" ~> dataType <~ ">" ^^ { - case tpe => ArrayType(tpe) - } - protected lazy val mapType: Parser[DataType] = - MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) - } - - protected lazy val structField: Parser[StructField] = - ident ~ ":" ~ dataType ^^ { - case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true) + StructField(columnName, typ, nullable = true, meta) } - - protected lazy val structType: Parser[DataType] = - (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { - case fields => StructType(fields) - }) | - (STRUCT ~> "<>" ^^ { - case fields => StructType(Nil) - }) - - private[sql] lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType } private[sql] object ResolvedDataSource { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index fe86bd206a71c..949a4e54e6c30 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -756,7 +756,7 @@ private[hive] case class MetastoreRelation implicit class SchemaAttribute(f: FieldSchema) { def toAttribute = AttributeReference( f.getName, - sqlContext.ddlParser.parseType(f.getType), + HiveMetastoreTypes.toDataType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) @@ -779,11 +779,7 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { - protected val ddlParser = new DDLParser(HiveQl.parseSql(_)) - - def toDataType(metastoreType: String): DataType = synchronized { - ddlParser.parseType(metastoreType) - } + def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType) def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" From b6090f902e6ec24923b4dde4aabc9076956521c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Mar 2015 14:30:04 -0700 Subject: [PATCH 30/47] [SPARK-6428][SQL] Added explicit type for all public methods for Hive module Author: Reynold Xin Closes #5108 from rxin/hive-public-type and squashes the following commits: a320328 [Reynold Xin] [SPARK-6428][SQL] Added explicit type for all public methods for Hive module. --- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 6 +-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 21 ++++----- .../org/apache/spark/sql/hive/HiveQl.scala | 10 ++--- .../spark/sql/hive/HiveStrategies.scala | 6 +-- .../apache/spark/sql/hive/TableReader.scala | 2 +- .../hive/execution/CreateTableAsSelect.scala | 2 +- .../execution/DescribeHiveTableCommand.scala | 2 +- .../hive/execution/HiveNativeCommand.scala | 4 +- .../sql/hive/execution/HiveTableScan.scala | 5 ++- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../hive/execution/ScriptTransformation.scala | 5 ++- .../spark/sql/hive/execution/commands.scala | 8 ++-- .../org/apache/spark/sql/hive/hiveUdfs.scala | 45 +++++++++++-------- .../spark/sql/hive/hiveWriterContainers.scala | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 13 ++++-- 16 files changed, 79 insertions(+), 62 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 895688ab2ec2e..6272cdedb3e48 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -194,8 +194,8 @@ private[hive] object SparkSQLCLIDriver { val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) - def promptWithCurrentDB = s"$prompt$currentDB" - def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + def promptWithCurrentDB: String = s"$prompt$currentDB" + def continuedPromptWithDBSpaces: String = continuedPrompt + ReflectionUtils.invokeStatic( classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) var currentPrompt = promptWithCurrentDB diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index a5c435fdfa778..c06c2e396bbc1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -183,7 +183,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient - protected lazy val outputBuffer = new java.io.OutputStream { + protected lazy val outputBuffer = new java.io.OutputStream { var pos: Int = 0 var buffer = new Array[Int](10240) def write(i: Int): Unit = { @@ -191,7 +191,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { pos = (pos + 1) % buffer.size } - override def toString = { + override def toString: String = { val (end, start) = buffer.splitAt(pos) val input = new java.io.InputStream { val iterator = (start ++ end).iterator @@ -227,7 +227,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient override protected[sql] lazy val functionRegistry = new HiveFunctionRegistry with OverrideFunctionRegistry { - def caseSensitive = false + def caseSensitive: Boolean = false } /* An analyzer that uses the Hive metastore. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 949a4e54e6c30..4c5eb48661f7d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -58,7 +58,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // TODO: Use this everywhere instead of tuples or databaseName, tableName,. /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) { - def toLowerCase = QualifiedTableName(database.toLowerCase, name.toLowerCase) + def toLowerCase: QualifiedTableName = QualifiedTableName(database.toLowerCase, name.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -629,7 +629,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with castChildOutput(p, table, child) } - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { + def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) + : LogicalPlan = { val childOutputDataTypes = child.output.map(_.dataType) val tableOutputDataTypes = (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) @@ -667,7 +668,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with */ override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? - override def unregisterAllTables() = {} + override def unregisterAllTables(): Unit = {} } /** @@ -682,10 +683,10 @@ private[hive] case class InsertIntoHiveTable( overwrite: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType) } } @@ -704,13 +705,13 @@ private[hive] case class MetastoreRelation // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException // which indicates the SerDe we used is not Serializable. - @transient val hiveQlTable = new Table(table) + @transient val hiveQlTable: Table = new Table(table) - @transient val hiveQlPartitions = partitions.map { p => + @transient val hiveQlPartitions: Seq[Partition] = partitions.map { p => new Partition(hiveQlTable, p) } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) @@ -754,7 +755,7 @@ private[hive] case class MetastoreRelation ) implicit class SchemaAttribute(f: FieldSchema) { - def toAttribute = AttributeReference( + def toAttribute: AttributeReference = AttributeReference( f.getName, HiveMetastoreTypes.toDataType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ced99cd082614..51775eb4cd6a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -196,8 +196,8 @@ private[hive] object HiveQl { * Right now this function only checks the name, type, text and children of the node * for equality. */ - def checkEquals(other: ASTNode) { - def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) { + def checkEquals(other: ASTNode): Unit = { + def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { sys.error(s"$field does not match for trees. " + s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") } @@ -209,7 +209,7 @@ private[hive] object HiveQl { val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] leftChildren zip rightChildren foreach { - case (l,r) => l checkEquals r + case (l, r) => l checkEquals r } } } @@ -269,7 +269,7 @@ private[hive] object HiveQl { } /** Creates LogicalPlan for a given VIEW */ - def createPlanForView(view: Table, alias: Option[String]) = alias match { + def createPlanForView(view: Table, alias: Option[String]): Subquery = alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." case None => Subquery(view.getTableName, createPlan(view.getViewExpandedText)) @@ -323,7 +323,7 @@ private[hive] object HiveQl { clauses } - def getClause(clauseName: String, nodeList: Seq[Node]) = + def getClause(clauseName: String, nodeList: Seq[Node]): Node = getClauseOption(clauseName, nodeList).getOrElse(sys.error( s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index e63cea60457d9..5f7e897295117 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -58,9 +58,9 @@ private[hive] trait HiveStrategies { @Experimental object ParquetConversion extends Strategy { implicit class LogicalPlanHacks(s: DataFrame) { - def lowerCase = DataFrame(s.sqlContext, s.logicalPlan) + def lowerCase: DataFrame = DataFrame(s.sqlContext, s.logicalPlan) - def addPartitioningAttributes(attrs: Seq[Attribute]) = { + def addPartitioningAttributes(attrs: Seq[Attribute]): DataFrame = { // Don't add the partitioning key if its already present in the data. if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s @@ -75,7 +75,7 @@ private[hive] trait HiveStrategies { } implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { - def fakeOutput(newOutput: Seq[Attribute]) = + def fakeOutput(newOutput: Seq[Attribute]): OutputFaker = OutputFaker( originalPlan.output.map(a => newOutput.find(a.name.toLowerCase == _.name.toLowerCase) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index f22c9eaeedc7d..af309c0c6ce2c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -175,7 +175,7 @@ class HadoopTableReader( relation.partitionKeys.contains(attr) } - def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = relation.partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index a0c91cbc4e86f..fade9e5852eaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -45,7 +45,7 @@ case class CreateTableAsSelect( allowExisting: Boolean, desc: Option[CreateTableDesc]) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { // Create Hive Table diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index d0510aa342796..6fce69b58b85e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -37,7 +37,7 @@ case class DescribeHiveTableCommand( override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 9636da206087f..60a9bb630d0d9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.types.StringType private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { - override def output = + override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext) = + override def run(sqlContext: SQLContext): Seq[Row] = sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 5b3cf2861e8ef..0a5f19eee7105 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ @@ -128,11 +129,11 @@ case class HiveTableScan( } } - override def execute() = if (!relation.hiveQlTable.isPartitioned) { + override def execute(): RDD[Row] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } - override def output = attributes + override def output: Seq[Attribute] = attributes } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index ba5c8e028a151..da53d30354551 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} @@ -58,7 +58,7 @@ case class InsertIntoHiveTable( serializer } - def output = child.output + def output: Seq[Attribute] = child.output def saveAsHiveFile( rdd: RDD[Row], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 0c9aee33985bc..8efed7f0299bf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ @@ -51,9 +52,9 @@ case class ScriptTransformation( ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) extends UnaryNode { - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs: Seq[HiveContext] = sc :: Nil - def execute() = { + def execute(): RDD[Row] = { child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 63ad145a6a980..4345ffbf30f77 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.StructType private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.asInstanceOf[HiveContext].analyze(tableName) Seq.empty[Row] } @@ -52,7 +52,7 @@ case class DropTable( tableName: String, ifExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { @@ -75,7 +75,7 @@ case class DropTable( private[hive] case class AddJar(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) @@ -86,7 +86,7 @@ case class AddJar(path: String) extends RunnableCommand { private[hive] case class AddFile(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD FILE $path") hiveContext.sparkContext.addFile(path) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 4a702d96563d5..bfe43373d9534 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -45,7 +45,7 @@ import scala.collection.JavaConversions._ private[hive] abstract class HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) + def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is @@ -78,7 +78,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre type EvaluatedType = Any type UDFType = UDF - def nullable = true + override def nullable: Boolean = true @transient lazy val function = funcWrapper.createFunction[UDFType]() @@ -96,7 +96,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre udfType != null && udfType.deterministic() } - override def foldable = isUDFDeterministic && children.forall(_.foldable) + override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) // Create parameter converters @transient @@ -110,7 +110,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre method.getGenericReturnType(), ObjectInspectorOptions.JAVA) @transient - protected lazy val cached = new Array[AnyRef](children.length) + protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) // TODO: Finish input output types. override def eval(input: Row): Any = { @@ -120,17 +120,19 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre returnInspector) } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } // Adapter from Catalyst ExpressionResult to Hive DeferredObject private[hive] class DeferredObjectAdapter(oi: ObjectInspector) extends DeferredObject with HiveInspectors { private var func: () => Any = _ - def set(func: () => Any) { + def set(func: () => Any): Unit = { this.func = func } - override def prepare(i: Int) = {} + override def prepare(i: Int): Unit = {} override def get(): AnyRef = wrap(func(), oi) } @@ -139,7 +141,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr type UDFType = GenericUDF type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true @transient lazy val function = funcWrapper.createFunction[UDFType]() @@ -158,7 +160,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr (udfType != null && udfType.deterministic()) } - override def foldable = + override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient @@ -182,7 +184,9 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr unwrap(function.evaluate(deferedObjects), returnInspector) } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } private[hive] case class HiveGenericUdaf( @@ -209,9 +213,11 @@ private[hive] case class HiveGenericUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } - def newInstance() = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ @@ -240,10 +246,11 @@ private[hive] case class HiveUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } - def newInstance() = - new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) } /** @@ -314,21 +321,23 @@ private[hive] case class HiveGenericUdtf( collected += unwrap(input, outputInspector).asInstanceOf[Row] } - def collectRows() = { + def collectRows(): Seq[Row] = { val toCollect = collected collected = new ArrayBuffer[Row] toCollect } } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } /** * Resolve Udtfs Alias. */ private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan) = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p @ Project(projectList, _) if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 => throw new TreeNodeException(p, "only single Generator supported for SELECT clause") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index f136e43acc8f2..ba2bf67aed684 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -222,7 +222,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( s"/$col=$colString" }.mkString - def newWriter = { + def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( fileSinkConf.getDirName + dynamicPartPath, fileSinkConf.getTableInfo, @@ -246,6 +246,6 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( Reporter.NULL) } - writers.getOrElseUpdate(dynamicPartPath, newWriter) + writers.getOrElseUpdate(dynamicPartPath, newWriter()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b4aee78046383..dc61e9d2e3522 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -155,8 +155,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution(HiveQl.parseSql(hql)) { - def hiveExec() = runSqlHive(hql) - override def toString = hql + "\n" + super.toString + def hiveExec(): Seq[String] = runSqlHive(hql) + override def toString: String = hql + "\n" + super.toString } /** @@ -186,7 +186,9 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case class TestTable(name: String, commands: (()=>Unit)*) protected[hive] implicit class SqlCmd(sql: String) { - def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit + def cmd: () => Unit = { + () => new HiveQLQueryExecution(sql).stringResult(): Unit + } } /** @@ -194,7 +196,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * demand when a query are run against it. */ lazy val testTables = new mutable.HashMap[String, TestTable]() - def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) + + def registerTestTable(testTable: TestTable): Unit = { + testTables += (testTable.name -> testTable) + } // The test tables that are defined in the Hive QTestUtil. // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java From 9b1e1f20d4498bda72dd53a832110883a7ca41b5 Mon Sep 17 00:00:00 2001 From: ypcat Date: Sun, 22 Mar 2015 15:49:13 +0800 Subject: [PATCH 31/47] [SPARK-6408] [SQL] Fix JDBCRDD filtering string literals Author: ypcat Author: Pei-Lun Lee Closes #5087 from ypcat/spark-6408 and squashes the following commits: 1becc16 [ypcat] [SPARK-6408] [SQL] styling 1bc4455 [ypcat] [SPARK-6408] [SQL] move nested function outside e57fa4a [ypcat] [SPARK-6408] [SQL] fix test case 245ab6f [ypcat] [SPARK-6408] [SQL] add test cases for filtering quoted strings 8962534 [Pei-Lun Lee] [SPARK-6408] [SQL] Fix filtering string literals --- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 19 ++++++++++++++----- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 12 ++++++++++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 3266b972128ea..76f8593180e85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import org.apache.commons.lang.StringEscapeUtils.escapeSql import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} @@ -226,16 +227,24 @@ private[sql] class JDBCRDD( if (sb.length == 0) "1" else sb.substring(1) } + /** + * Converts value to SQL expression. + */ + private def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case _ => value + } + /** * Turns a single Filter into a String representing a SQL expression. * Returns null for an unhandled filter. */ private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = $value" - case LessThan(attr, value) => s"$attr < $value" - case GreaterThan(attr, value) => s"$attr > $value" - case LessThanOrEqual(attr, value) => s"$attr <= $value" - case GreaterThanOrEqual(attr, value) => s"$attr >= $value" + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index cd737c0b62767..5eb6ab2e92e8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -24,6 +24,7 @@ import java.util.{Calendar, GregorianCalendar} import org.apache.spark.sql.test._ import org.scalatest.{FunSuite, BeforeAndAfter} import TestSQLContext._ +import TestSQLContext.implicits._ class JDBCSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" @@ -38,7 +39,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() - conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate() + conn.prepareStatement("insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() sql( @@ -129,13 +130,20 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0) assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2) assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size == 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size == 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size == 2) + } + + test("SELECT * WHERE (quoted strings)") { + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size == 1) } test("SELECT first field") { val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size == 3) assert(names(0).equals("fred")) - assert(names(1).equals("joe")) + assert(names(1).equals("joe 'foo' \"bar\"")) assert(names(2).equals("mary")) } From b9fe504b497cfa509310b4045de4873739c76667 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 22 Mar 2015 11:54:23 +0000 Subject: [PATCH 32/47] [SPARK-6448] Make history server log parse exceptions This helped me to debug a parse error that was due to the event log format changing recently. Author: Ryan Williams Closes #5122 from ryan-williams/histerror and squashes the following commits: 5831656 [Ryan Williams] line length c3742ae [Ryan Williams] Make history server log parse exceptions --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 7fde02040927d..db7c499661319 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -233,7 +233,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } catch { case e: Exception => logError( - s"Exception encountered when attempting to load application log ${fileStatus.getPath}") + s"Exception encountered when attempting to load application log ${fileStatus.getPath}", + e) None } }.toSeq.sortWith(compareAppInfo) From ab4f516fbe63e24e076c68f4933a171a72b6f1fd Mon Sep 17 00:00:00 2001 From: Hangchen Yu Date: Sun, 22 Mar 2015 15:51:10 +0000 Subject: [PATCH 33/47] [SPARK-6455] [docs] Correct some mistakes and typos Correct some typos. Correct a mistake in lib/PageRank.scala. The first PageRank implementation uses standalone Graph interface, but the second uses Pregel interface. It may mislead the code viewers. Author: Hangchen Yu Closes #5128 from yuhc/master and squashes the following commits: 53e5432 [Hangchen Yu] Merge branch 'master' of https://github.com/yuhc/spark 67b77b5 [Hangchen Yu] [SPARK-6455] [docs] Correct some mistakes and typos 206f2dc [Hangchen Yu] Correct some mistakes and typos. --- graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala | 4 ++-- graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala | 4 ++-- .../src/main/scala/org/apache/spark/graphx/lib/PageRank.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index dc8b4789c4b61..86f611d55aa8a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -113,7 +113,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Collect the neighbor vertex attributes for each vertex. * * @note This function could be highly inefficient on power-law - * graphs where high degree vertices may force a large ammount of + * graphs where high degree vertices may force a large amount of * information to be collected to a single location. * * @param edgeDirection the direction along which to collect @@ -187,7 +187,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali /** * Join the vertices with an RDD and then apply a function from the - * the vertex and RDD entry to a new vertex value. The input table + * vertex and RDD entry to a new vertex value. The input table * should contain at most one entry for each vertex. If no entry is * provided the map function is skipped and the old value is used. * diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 5e55620147df8..01b013ff716fc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -78,8 +78,8 @@ object Pregel extends Logging { * * @param graph the input graph. * - * @param initialMsg the message each vertex will receive at the on - * the first iteration + * @param initialMsg the message each vertex will receive at the first + * iteration * * @param maxIterations the maximum number of iterations to run for * diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index e139959c3f5c1..ca3b513821e13 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -25,7 +25,7 @@ import org.apache.spark.graphx._ /** * PageRank algorithm implementation. There are two implementations of PageRank implemented. * - * The first implementation uses the [[Pregel]] interface and runs PageRank for a fixed number + * The first implementation uses the standalone [[Graph]] interface and runs PageRank for a fixed number * of iterations: * {{{ * var PR = Array.fill(n)( 1.0 ) @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ * } * }}} * - * The second implementation uses the standalone [[Graph]] interface and runs PageRank until + * The second implementation uses the [[Pregel]] interface and runs PageRank until * convergence: * * {{{ From adb2ff752fa8bda54c969b60a3168d87cd70237d Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Sun, 22 Mar 2015 15:53:18 +0000 Subject: [PATCH 34/47] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Moved Suites from o.a.s.s.mesos to o.a.s.s.cluster.mesos Author: Jongyoul Lee Closes #5126 from jongyoul/SPARK-6453 and squashes the following commits: 4f24a3e [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Fixed imports orders 8ab149d [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Moved Suites from o.a.s.s.mesos to o.a.s.s.cluster.mesos --- .../mesos/MesosSchedulerBackendSuite.scala | 13 ++++++------- .../mesos/MesosTaskLaunchDataSuite.scala | 4 +--- 2 files changed, 7 insertions(+), 10 deletions(-) rename core/src/test/scala/org/apache/spark/scheduler/{ => cluster}/mesos/MesosSchedulerBackendSuite.scala (98%) rename core/src/test/scala/org/apache/spark/scheduler/{ => cluster}/mesos/MesosTaskLaunchDataSuite.scala (92%) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala similarity index 98% rename from core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index afbaa9ade811f..f1a4380d349b3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.mesos +package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer import java.util @@ -24,21 +24,20 @@ import java.util.Collections import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.mesos.SchedulerDriver -import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.Scalar -import org.mockito.Mockito._ +import org.apache.mesos.Protos._ +import org.apache.mesos.SchedulerDriver import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.mockito.{ArgumentCaptor, Matchers} import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} import org.apache.spark.executor.MesosExecutorBackend +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.cluster.mesos.{MesosSchedulerBackend, MemoryUtils} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala similarity index 92% rename from core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala index 86a42a7398e4d..eebcba40f8a1c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala @@ -15,14 +15,12 @@ * limitations under the License. */ -package org.apache.spark.scheduler.mesos +package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer import org.scalatest.FunSuite -import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData - class MesosTaskLaunchDataSuite extends FunSuite { test("serialize and deserialize data must be same") { val serializedTask = ByteBuffer.allocate(40) From 6ef48632fbf3e6659ceacaab1dbb8be8238d4d33 Mon Sep 17 00:00:00 2001 From: Kamil Smuga Date: Sun, 22 Mar 2015 15:56:25 +0000 Subject: [PATCH 35/47] SPARK-6454 [DOCS] Fix links to pyspark api Author: Kamil Smuga Author: stderr Closes #5120 from kamilsmuga/master and squashes the following commits: fee3281 [Kamil Smuga] more python api links fixed for docs 13240cb [Kamil Smuga] resolved merge conflicts with upstream/master 6649b3b [Kamil Smuga] fix broken docs links to Python API 92f03d7 [stderr] Fix links to pyspark api --- docs/mllib-data-types.md | 8 ++++---- docs/mllib-naive-bayes.md | 6 +++--- docs/mllib-statistics.md | 10 +++++----- docs/programming-guide.md | 12 ++++++------ docs/sql-programming-guide.md | 2 +- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index fe6c1bf7bfd99..4f2a2f71048f7 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -78,13 +78,13 @@ MLlib recognizes the following types as dense vectors: and the following as sparse vectors: -* MLlib's [`SparseVector`](api/python/pyspark.mllib.linalg.SparseVector-class.html). +* MLlib's [`SparseVector`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseVector). * SciPy's [`csc_matrix`](http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix) with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.linalg.Vectors-class.html) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors. {% highlight python %} import numpy as np @@ -151,7 +151,7 @@ LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new
A labeled point is represented by -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html). +[`LabeledPoint`](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint). {% highlight python %} from pyspark.mllib.linalg import SparseVector @@ -211,7 +211,7 @@ JavaRDD examples =
-[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.util.MLUtils-class.html) reads training +[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) reads training examples stored in LIBSVM format. {% highlight python %} diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 55b8f2ce6c364..a83472f5be52e 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -106,11 +106,11 @@ NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath");
-[NaiveBayes](api/python/pyspark.mllib.classification.NaiveBayes-class.html) implements multinomial +[NaiveBayes](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) implements multinomial naive Bayes. It takes an RDD of -[LabeledPoint](api/python/pyspark.mllib.regression.LabeledPoint-class.html) and an optionally +[LabeledPoint](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint) and an optionally smoothing parameter `lambda` as input, and output a -[NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be +[NaiveBayesModel](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index ca8c29218f52d..887eae7f4f07b 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -81,8 +81,8 @@ System.out.println(summary.numNonzeros()); // number of nonzeros in each column
-[`colStats()`](api/python/pyspark.mllib.stat.Statistics-class.html#colStats) returns an instance of -[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.stat.MultivariateStatisticalSummary-class.html), +[`colStats()`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics.colStats) returns an instance of +[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.html#pyspark.mllib.stat.MultivariateStatisticalSummary), which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. @@ -169,7 +169,7 @@ Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson");
-[`Statistics`](api/python/pyspark.mllib.stat.Statistics-class.html) provides methods to +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. @@ -258,7 +258,7 @@ JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions); {% endhighlight %}
-[`sampleByKey()`](api/python/pyspark.rdd.RDD-class.html#sampleByKey) allows users to +[`sampleByKey()`](api/python/pyspark.html#pyspark.RDD.sampleByKey) allows users to sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of keys. @@ -476,7 +476,7 @@ JavaDoubleRDD v = u.map(
-[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +[`RandomRDDs`](api/python/pyspark.mllib.html#pyspark.mllib.random.RandomRDDs) provides factory methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index eda3a95426182..5fe832b6fa100 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -142,8 +142,8 @@ JavaSparkContext sc = new JavaSparkContext(conf);
-The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.context.SparkContext-class.html) object, which tells Spark -how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.conf.SparkConf-class.html) object +The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object, which tells Spark +how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.html#pyspark.SparkConf) object that contains information about your application. {% highlight python %} @@ -912,7 +912,7 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1025,7 +1025,7 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1105,7 +1105,7 @@ replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-proj These levels are set by passing a `StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel), [Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html), -[Python](api/python/pyspark.storagelevel.StorageLevel-class.html)) +[Python](api/python/pyspark.html#pyspark.StorageLevel)) to `persist()`. The `cache()` method is a shorthand for using the default storage level, which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The full set of storage levels is: @@ -1374,7 +1374,7 @@ scala> accum.value {% endhighlight %} While this code used the built-in support for accumulators of type Int, programmers can also -create their own types by subclassing [AccumulatorParam](api/python/pyspark.accumulators.AccumulatorParam-class.html). +create their own types by subclassing [AccumulatorParam](api/python/pyspark.html#pyspark.AccumulatorParam). The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class representing mathematical vectors, we could write: diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2cbb4c967eb81..a7d35741a48c3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -56,7 +56,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
The entry point into all relational functionality in Spark is the -[`SQLContext`](api/python/pyspark.sql.SQLContext-class.html) class, or one +[`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight python %} From a41b9c6004cfee84bd56dfa1faf5a0cf084551ae Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Sun, 22 Mar 2015 11:11:29 -0700 Subject: [PATCH 36/47] [SPARK-6122][Core] Upgrade Tachyon client version to 0.6.1. Changes the Tachyon client version from 0.5 to 0.6 in spark core and distribution script. New dependencies in Tachyon 0.6.0 include commons-codec:commons-codec:jar:1.5:compile io.netty:netty-all:jar:4.0.23.Final:compile These are already in spark core. Author: Calvin Jia Closes #4867 from calvinjia/upgrade_tachyon_0.6.0 and squashes the following commits: eed9230 [Calvin Jia] Update tachyon version to 0.6.1. 11907b3 [Calvin Jia] Use TachyonURI for tachyon paths instead of strings. 71bf441 [Calvin Jia] Upgrade Tachyon client version to 0.6.0. --- core/pom.xml | 2 +- .../spark/storage/TachyonBlockManager.scala | 27 +++++++++---------- .../scala/org/apache/spark/util/Utils.scala | 4 ++- make-distribution.sh | 2 +- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 6cd1965ec37c2..868834dd505ef 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -275,7 +275,7 @@ org.tachyonproject tachyon-client - 0.5.0 + 0.6.1 org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index af873034215a9..2ab6a8f3ec1d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -20,8 +20,8 @@ package org.apache.spark.storage import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.client.TachyonFS -import tachyon.client.TachyonFile +import tachyon.TachyonURI +import tachyon.client.{TachyonFile, TachyonFS} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -40,7 +40,7 @@ private[spark] class TachyonBlockManager( val master: String) extends Logging { - val client = if (master != null && master != "") TachyonFS.get(master) else null + val client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -60,11 +60,11 @@ private[spark] class TachyonBlockManager( addShutdownHook() def removeFile(file: TachyonFile): Boolean = { - client.delete(file.getPath(), false) + client.delete(new TachyonURI(file.getPath()), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(file.getPath()) + client.exist(new TachyonURI(file.getPath())) } def getFile(filename: String): TachyonFile = { @@ -81,7 +81,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) + val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +89,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = subDir + "/" + filename + val filePath = new TachyonURI(s"$subDir/$filename") if(!client.exist(filePath)) { client.createFile(filePath) } @@ -101,7 +101,7 @@ private[spark] class TachyonBlockManager( // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore. private def createTachyonDirs(): Array[TachyonFile] = { - logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'") + logDebug(s"Creating tachyon directories at root dirs '$rootDirs'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") rootDirs.split(",").map { rootDir => var foundLocalDir = false @@ -113,22 +113,21 @@ private[spark] class TachyonBlockManager( tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId + val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) } } catch { case e: Exception => - logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) + logWarning(s"Attempt $tries to create tachyon dir $tachyonDir failed", e) } } if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create tachyon dir in " + - rootDir) + logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create tachyon dir in $rootDir") System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_CREATE_DIR) } - logInfo("Created tachyon directory at " + tachyonDir) + logInfo(s"Created tachyon directory at $tachyonDir") tachyonDir } } @@ -145,7 +144,7 @@ private[spark] class TachyonBlockManager( } } catch { case e: Exception => - logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) + logError(s"Exception while deleting tachyon spark dir: $tachyonDir", e) } } client.close() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fa56bb09e4e5c..91d833295e376 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -42,6 +42,8 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ + +import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ @@ -970,7 +972,7 @@ private[spark] object Utils extends Logging { * Delete a file or directory and its contents recursively. */ def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(dir.getPath(), true)) { + if (!client.delete(new TachyonURI(dir.getPath()), true)) { throw new IOException("Failed to delete the tachyon dir: " + dir) } } diff --git a/make-distribution.sh b/make-distribution.sh index 9ed1abfe8c598..8162fe94c1af0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.5.0" +TACHYON_VERSION="0.6.1" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" From 7a0da47708b0e6b117b5c1a35aa3e93b8a914d5f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 22 Mar 2015 12:08:15 -0700 Subject: [PATCH 37/47] [HOTFIX] Build break due to https://github.com/apache/spark/pull/5128 --- .../src/main/scala/org/apache/spark/graphx/lib/PageRank.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index ca3b513821e13..570440ba4441f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -25,8 +25,8 @@ import org.apache.spark.graphx._ /** * PageRank algorithm implementation. There are two implementations of PageRank implemented. * - * The first implementation uses the standalone [[Graph]] interface and runs PageRank for a fixed number - * of iterations: + * The first implementation uses the standalone [[Graph]] interface and runs PageRank + * for a fixed number of iterations: * {{{ * var PR = Array.fill(n)( 1.0 ) * val oldPR = Array.fill(n)( 1.0 ) From 2bf40c58e6e89e061783c999204107069df17f73 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Sun, 22 Mar 2015 20:00:08 +0000 Subject: [PATCH 38/47] [SPARK-6337][Documentation, SQL]Spark 1.3 doc fixes Author: vinodkc Closes #5112 from vinodkc/spark_1.3_doc_fixes and squashes the following commits: 2c6aee6 [vinodkc] Spark 1.3 doc fixes --- docs/sql-programming-guide.md | 7 +++++-- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a7d35741a48c3..6a333fdb562a7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -509,8 +509,11 @@ val people = sc.textFile("examples/src/main/resources/people.txt") // The schema is encoded in a string val schemaString = "name age" -// Import Spark SQL data types and Row. -import org.apache.spark.sql._ +// Import Row. +import org.apache.spark.sql.Row; + +// Import Spark SQL data types +import org.apache.spark.sql.types.{StructType,StructField,StringType}; // Generate the schema based on the string of schema val schema = diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 5bbcd2e080e07..c4a36103303a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.StructType abstract class PipelineStage extends Serializable with Logging { /** - * :: DeveloperAPI :: + * :: DeveloperApi :: * * Derives the output schema from the input schema and parameters. * The schema describes the columns and types of the data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8b8f86c4127e0..5aece166aad22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -89,7 +89,7 @@ private[sql] object DataFrame { * val people = sqlContext.parquetFile("...") * val department = sqlContext.parquetFile("...") * - * people.filter("age" > 30) + * people.filter("age > 30") * .join(department, people("deptId") === department("id")) * .groupBy(department("name"), "gender") * .agg(avg(people("salary")), max(people("age"))) From 4659468f369d69e7f777130e5e3b4c5d47a624f1 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 23 Mar 2015 11:46:16 +0800 Subject: [PATCH 39/47] [SPARK-4985] [SQL] parquet support for date type This PR might have some issues with #3732 , and this would have merge conflicts with #3820 so the review can be delayed till that 2 were merged. Author: Daoyuan Wang Closes #3822 from adrian-wang/parquetdate and squashes the following commits: 2c5d54d [Daoyuan Wang] add a test case faef887 [Daoyuan Wang] parquet support for primitive date 97e9080 [Daoyuan Wang] parquet support for date type --- .../spark/sql/parquet/ParquetConverter.scala | 12 ++++++++++++ .../spark/sql/parquet/ParquetTableSupport.scala | 2 ++ .../apache/spark/sql/parquet/ParquetTypes.scala | 4 ++++ .../apache/spark/sql/parquet/ParquetIOSuite.scala | 15 +++++++++++++++ .../spark/sql/parquet/ParquetSchemaSuite.scala | 3 ++- 5 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index f898e4b37a56b..43ca359b51735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -127,6 +127,12 @@ private[sql] object CatalystConverter { parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) } } + case DateType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateDate(fieldIndex, value.asInstanceOf[DateType.JvmType]) + } + } case d: DecimalType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addBinary(value: Binary): Unit = @@ -192,6 +198,9 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = updateField(fieldIndex, value) + protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + updateField(fieldIndex, value) + protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = updateField(fieldIndex, value) @@ -388,6 +397,9 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = current.setInt(fieldIndex, value) + override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + current.update(fieldIndex, value) + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 19bfba34b8f4a..5a1b15490d273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -212,6 +212,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case DateType => writer.addInteger(value.asInstanceOf[Int]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -358,6 +359,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case DateType => writer.addInteger(record.getInt(index)) case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 5209581fa8357..da668f068613b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -64,6 +64,8 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 + if originalType == ParquetOriginalType.DATE => DateType case ParquetPrimitiveTypeName.INT32 => IntegerType case ParquetPrimitiveTypeName.INT64 => LongType case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType @@ -222,6 +224,8 @@ private[parquet] object ParquetTypesConverter extends Logging { // There is no type for Byte or Short so we promote them to INT32. case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case DateType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) case DecimalType.Fixed(precision, scale) if precision <= 18 => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 5438095addeaf..203bc79f153dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -135,6 +135,21 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } + test("date type") { + def makeDateRDD(): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(DateUtils.toJavaDate(i))) + .toDF() + .select($"_1") + + withTempPath { dir => + val data = makeDateRDD() + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + } + } + test("map") { val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) checkParquetFile(data) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index ad880e2bc3679..321832cd43211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -57,7 +57,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Byte, Short, Int, Long)]( + testSchema[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -65,6 +65,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { | required int32 _2 (INT_16); | required int32 _3 (INT_32); | required int64 _4 (INT_64); + | optional int32 _5 (DATE); |} """.stripMargin) From e566fe5982bac5d24e6be76e5d7d6270544a85e6 Mon Sep 17 00:00:00 2001 From: q00251598 Date: Mon, 23 Mar 2015 12:06:13 +0800 Subject: [PATCH 40/47] [SPARK-6397][SQL] Check the missingInput simply Author: q00251598 Closes #5082 from watermen/sql-missingInput and squashes the following commits: 25766b9 [q00251598] Check the missingInput simply --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 5 ++--- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4e8fc892f3eea..fb975ee5e7296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -85,9 +85,8 @@ class CheckAnalysis { cleaned.foreach(checkValidAggregateExpression) - case o if o.children.nonEmpty && - !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => - val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") + case o if o.children.nonEmpty && o.missingInput.nonEmpty => + val missingAttributes = o.missingInput.map(_.prettyString).mkString(",") val input = o.inputSet.map(_.prettyString).mkString(",") failAnalysis(s"resolved attributes $missingAttributes missing from $input") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 384fe53a68362..a94b2d2095d12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -191,6 +191,8 @@ case class Expand( val sizeInBytes = child.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + override def missingInput = super.missingInput.filter(_.name != VirtualColumn.groupingIdName) } trait GroupingAnalytics extends UnaryNode { From bf044def4c3a37a0fd4d5e70c2d57685cfd9fd71 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 23 Mar 2015 12:15:19 +0800 Subject: [PATCH 41/47] Revert "[SPARK-6397][SQL] Check the missingInput simply" This reverts commit e566fe5982bac5d24e6be76e5d7d6270544a85e6. --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 5 +++-- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fb975ee5e7296..4e8fc892f3eea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -85,8 +85,9 @@ class CheckAnalysis { cleaned.foreach(checkValidAggregateExpression) - case o if o.children.nonEmpty && o.missingInput.nonEmpty => - val missingAttributes = o.missingInput.map(_.prettyString).mkString(",") + case o if o.children.nonEmpty && + !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => + val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") val input = o.inputSet.map(_.prettyString).mkString(",") failAnalysis(s"resolved attributes $missingAttributes missing from $input") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a94b2d2095d12..384fe53a68362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -191,8 +191,6 @@ case class Expand( val sizeInBytes = child.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } - - override def missingInput = super.missingInput.filter(_.name != VirtualColumn.groupingIdName) } trait GroupingAnalytics extends UnaryNode { From 9f3273bd9c919f6c48a95383b3d5be357c89998c Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Mon, 23 Mar 2015 18:16:49 +0800 Subject: [PATCH 42/47] [SPARK-6397][SQL] Check the missingInput simply https://github.com/apache/spark/pull/5082 /cc liancheng Author: Yadong Qi Closes #5132 from watermen/sql-missingInput-new and squashes the following commits: 1e5bdc5 [Yadong Qi] Check the missingInput simply --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 5 ++--- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4e8fc892f3eea..fb975ee5e7296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -85,9 +85,8 @@ class CheckAnalysis { cleaned.foreach(checkValidAggregateExpression) - case o if o.children.nonEmpty && - !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => - val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") + case o if o.children.nonEmpty && o.missingInput.nonEmpty => + val missingAttributes = o.missingInput.map(_.prettyString).mkString(",") val input = o.inputSet.map(_.prettyString).mkString(",") failAnalysis(s"resolved attributes $missingAttributes missing from $input") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 17a88e07de15f..400a6b2825c10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} @@ -48,7 +48,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. */ - def missingInput: AttributeSet = references -- inputSet + def missingInput: AttributeSet = (references -- inputSet) + .filter(_.name != VirtualColumn.groupingIdName) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. From 474d1320c9b93c501710ad1cfa836b8284562a2c Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 23 Mar 2015 13:30:21 -0700 Subject: [PATCH 43/47] [SPARK-6308] [MLlib] [Sql] Override TypeName in VectorUDT and MatrixUDT Author: MechCoder Closes #5118 from MechCoder/spark-6308 and squashes the following commits: 6c8ffab [MechCoder] Add test for simpleString b966242 [MechCoder] [SPARK-6308] [MLlib][Sql] VectorUDT is displayed as vecto in dtypes --- .../src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 2 ++ .../src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 ++ .../scala/org/apache/spark/mllib/linalg/MatricesSuite.scala | 2 ++ .../test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 2 ++ 4 files changed, 8 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 849f44295f089..d1a174063caba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -187,6 +187,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def hashCode(): Int = 1994 + override def typeName: String = "matrix" + private[spark] override def asNullable: MatrixUDT = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 2cda9b252ee06..328dbe2ce11fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -185,6 +185,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def hashCode: Int = 7919 + override def typeName: String = "vector" + private[spark] override def asNullable: VectorUDT = this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 96f677db3f377..0d2cec58e2c03 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -436,5 +436,7 @@ class MatricesSuite extends FunSuite { Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach { mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray) } + assert(mUDT.typeName == "matrix") + assert(mUDT.simpleString == "matrix") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 5def899cea117..2839c4c289b2d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -187,6 +187,8 @@ class VectorsSuite extends FunSuite { for (v <- Seq(dv0, dv1, sv0, sv1)) { assert(v === udt.deserialize(udt.serialize(v))) } + assert(udt.typeName == "vector") + assert(udt.simpleString == "vector") } test("fromBreeze") { From 6cd7058b369ec8d01938961f148734ee9eaf76de Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 23 Mar 2015 15:08:39 -0700 Subject: [PATCH 44/47] Revert "[SPARK-6122][Core] Upgrade Tachyon client version to 0.6.1." This reverts commit a41b9c6004cfee84bd56dfa1faf5a0cf084551ae. --- core/pom.xml | 2 +- .../spark/storage/TachyonBlockManager.scala | 27 ++++++++++--------- .../scala/org/apache/spark/util/Utils.scala | 4 +-- make-distribution.sh | 2 +- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 868834dd505ef..6cd1965ec37c2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -275,7 +275,7 @@ org.tachyonproject tachyon-client - 0.6.1 + 0.5.0 org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 2ab6a8f3ec1d4..af873034215a9 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -20,8 +20,8 @@ package org.apache.spark.storage import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.TachyonURI -import tachyon.client.{TachyonFile, TachyonFS} +import tachyon.client.TachyonFS +import tachyon.client.TachyonFile import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -40,7 +40,7 @@ private[spark] class TachyonBlockManager( val master: String) extends Logging { - val client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null + val client = if (master != null && master != "") TachyonFS.get(master) else null if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -60,11 +60,11 @@ private[spark] class TachyonBlockManager( addShutdownHook() def removeFile(file: TachyonFile): Boolean = { - client.delete(new TachyonURI(file.getPath()), false) + client.delete(file.getPath(), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(new TachyonURI(file.getPath())) + client.exist(file.getPath()) } def getFile(filename: String): TachyonFile = { @@ -81,7 +81,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") + val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +89,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = new TachyonURI(s"$subDir/$filename") + val filePath = subDir + "/" + filename if(!client.exist(filePath)) { client.createFile(filePath) } @@ -101,7 +101,7 @@ private[spark] class TachyonBlockManager( // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore. private def createTachyonDirs(): Array[TachyonFile] = { - logDebug(s"Creating tachyon directories at root dirs '$rootDirs'") + logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") rootDirs.split(",").map { rootDir => var foundLocalDir = false @@ -113,21 +113,22 @@ private[spark] class TachyonBlockManager( tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") + val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) } } catch { case e: Exception => - logWarning(s"Attempt $tries to create tachyon dir $tachyonDir failed", e) + logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) } } if (!foundLocalDir) { - logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create tachyon dir in $rootDir") + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create tachyon dir in " + + rootDir) System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_CREATE_DIR) } - logInfo(s"Created tachyon directory at $tachyonDir") + logInfo("Created tachyon directory at " + tachyonDir) tachyonDir } } @@ -144,7 +145,7 @@ private[spark] class TachyonBlockManager( } } catch { case e: Exception => - logError(s"Exception while deleting tachyon spark dir: $tachyonDir", e) + logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } client.close() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 91d833295e376..fa56bb09e4e5c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -42,8 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ - -import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ @@ -972,7 +970,7 @@ private[spark] object Utils extends Logging { * Delete a file or directory and its contents recursively. */ def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(new TachyonURI(dir.getPath()), true)) { + if (!client.delete(dir.getPath(), true)) { throw new IOException("Failed to delete the tachyon dir: " + dir) } } diff --git a/make-distribution.sh b/make-distribution.sh index 8162fe94c1af0..9ed1abfe8c598 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.6.1" +TACHYON_VERSION="0.5.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" From bfd3ee9f76aaab3dcde71d92e2b8ca60a0e42262 Mon Sep 17 00:00:00 2001 From: Volodymyr Lyubinets Date: Mon, 23 Mar 2015 17:00:27 -0700 Subject: [PATCH 45/47] [SPARK-6124] Support jdbc connection properties in OPTIONS part of the query One more thing if this PR is considered to be OK - it might make sense to add extra .jdbc() API's that take Properties to SQLContext. Author: Volodymyr Lyubinets Closes #4859 from vlyubin/jdbcProperties and squashes the following commits: 7a8cfda [Volodymyr Lyubinets] Support jdbc connection properties in OPTIONS part of the query --- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 14 +++-- .../apache/spark/sql/jdbc/JDBCRelation.scala | 19 ++++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 55 +++++++++++++------ 3 files changed, 59 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 76f8593180e85..463e1dcc268bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.util.Properties import org.apache.commons.lang.StringEscapeUtils.escapeSql import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} @@ -90,9 +91,9 @@ private[sql] object JDBCRDD extends Logging { * @throws SQLException if the table specification is garbage. * @throws SQLException if the table contains an unsupported type. */ - def resolveTable(url: String, table: String): StructType = { + def resolveTable(url: String, table: String, properties: Properties): StructType = { val quirks = DriverQuirks.get(url) - val conn: Connection = DriverManager.getConnection(url) + val conn: Connection = DriverManager.getConnection(url, properties) try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() try { @@ -147,7 +148,7 @@ private[sql] object JDBCRDD extends Logging { * * @return A function that loads the driver and connects to the url. */ - def getConnector(driver: String, url: String): () => Connection = { + def getConnector(driver: String, url: String, properties: Properties): () => Connection = { () => { try { if (driver != null) Class.forName(driver) @@ -156,7 +157,7 @@ private[sql] object JDBCRDD extends Logging { logWarning(s"Couldn't find class $driver", e); } } - DriverManager.getConnection(url) + DriverManager.getConnection(url, properties) } } /** @@ -179,6 +180,7 @@ private[sql] object JDBCRDD extends Logging { schema: StructType, driver: String, url: String, + properties: Properties, fqTable: String, requiredColumns: Array[String], filters: Array[Filter], @@ -189,7 +191,7 @@ private[sql] object JDBCRDD extends Logging { return new JDBCRDD( sc, - getConnector(driver, url), + getConnector(driver, url, properties), prunedSchema, fqTable, requiredColumns, @@ -361,7 +363,7 @@ private[sql] class JDBCRDD( var ans = 0L var j = 0 while (j < bytes.size) { - ans = 256*ans + (255 & bytes(j)) + ans = 256 * ans + (255 & bytes(j)) j = j + 1; } mutableRow.setLong(i, ans) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index df687e6da9bea..4fa84dc076f7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.types.StructType +import java.sql.DriverManager +import java.util.Properties import scala.collection.mutable.ArrayBuffer -import java.sql.DriverManager import org.apache.spark.Partition +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType /** * Data corresponding to one partition of a JDBCRDD. @@ -115,18 +116,21 @@ private[sql] class DefaultSource extends RelationProvider { numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(url, table, parts)(sqlContext) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) } } private[sql] case class JDBCRelation( url: String, table: String, - parts: Array[Partition])(@transient val sqlContext: SQLContext) + parts: Array[Partition], + properties: Properties = new Properties())(@transient val sqlContext: SQLContext) extends BaseRelation with PrunedFilteredScan { - override val schema: StructType = JDBCRDD.resolveTable(url, table) + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName @@ -135,6 +139,7 @@ private[sql] case class JDBCRelation( schema, driver, url, + properties, table, requiredColumns, filters, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5eb6ab2e92e8b..592ed4b23b7d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -19,22 +19,31 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal import java.sql.DriverManager -import java.util.{Calendar, GregorianCalendar} +import java.util.{Calendar, GregorianCalendar, Properties} import org.apache.spark.sql.test._ +import org.h2.jdbc.JdbcSQLException import org.scalatest.{FunSuite, BeforeAndAfter} import TestSQLContext._ import TestSQLContext.implicits._ class JDBCSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" + val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) before { Class.forName("org.h2.Driver") - conn = DriverManager.getConnection(url) + // Extra properties that will be specified for our database. We need these to test + // usage of parameters from OPTIONS clause in queries. + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() @@ -46,15 +55,15 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE foobar |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) sql( s""" |CREATE TEMPORARY TABLE parts |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', - |partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, " @@ -68,12 +77,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE inttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.INTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), " + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate() - var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") + val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") stmt.setBytes(1, testBytes) stmt.setString(2, "Sensitive") stmt.setString(3, "Insensitive") @@ -85,7 +94,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE strtypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.STRTYPES') + |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)" @@ -97,7 +106,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE timetypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES') + |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) @@ -112,7 +121,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE flttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. @@ -174,16 +183,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.jdbc(url, "TEST.PEOPLE").collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3) } test("Partitioning via JDBCPartitioningInfo API") { - assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", "THEID", 0, 4, 3).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) + .collect.size == 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", parts).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3) } test("H2 integral types") { @@ -216,7 +226,6 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(rows(0).getString(5).equals("I am a clob!")) } - test("H2 time types") { val rows = sql("SELECT * FROM timetypes").collect() val cal = new GregorianCalendar(java.util.Locale.ROOT) @@ -246,17 +255,31 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { .equals(new BigDecimal("123456789012345.54321543215432100000"))) } - test("SQL query as table name") { sql( s""" |CREATE TEMPORARY TABLE hack |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)') + |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)', + | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) val rows = sql("SELECT * FROM hack").collect() assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } + + test("Pass extra properties via OPTIONS") { + // We set rowId to false during setup, which means that _ROWID_ column should be absent from + // all tables. If rowId is true (default), the query below doesn't throw an exception. + intercept[JdbcSQLException] { + sql( + s""" + |CREATE TEMPORARY TABLE abc + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + } + } } From 4ce2782a61e23ed0326faac2ee97a9bd36ec8963 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 23 Mar 2015 23:41:06 -0700 Subject: [PATCH 46/47] [SPARK-6428] Added explicit types for all public methods in core. Author: Reynold Xin Closes #5125 from rxin/core-explicit-type and squashes the following commits: f471415 [Reynold Xin] Revert style checker changes. 81b66e4 [Reynold Xin] Code review feedback. a7533e3 [Reynold Xin] Mima excludes. 1d795f5 [Reynold Xin] [SPARK-6428] Added explicit types for all public methods in core. --- .../scala/org/apache/spark/Accumulators.scala | 23 +++-- .../scala/org/apache/spark/Dependency.scala | 6 +- .../scala/org/apache/spark/FutureAction.scala | 4 +- .../org/apache/spark/HeartbeatReceiver.scala | 2 +- .../org/apache/spark/MapOutputTracker.scala | 2 +- .../scala/org/apache/spark/Partitioner.scala | 4 +- .../apache/spark/SerializableWritable.scala | 6 +- .../scala/org/apache/spark/SparkConf.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 48 +++++----- .../scala/org/apache/spark/TaskState.scala | 4 +- .../scala/org/apache/spark/TestUtils.scala | 2 +- .../org/apache/spark/api/java/JavaRDD.scala | 2 +- .../spark/api/java/JavaSparkContext.scala | 2 +- .../org/apache/spark/api/java/JavaUtils.scala | 35 ++++---- .../apache/spark/api/python/PythonRDD.scala | 39 ++++---- .../apache/spark/api/python/SerDeUtil.scala | 2 +- .../WriteInputFormatTestDataGenerator.scala | 27 +++--- .../apache/spark/broadcast/Broadcast.scala | 2 +- .../spark/broadcast/BroadcastManager.scala | 2 +- .../spark/broadcast/HttpBroadcast.scala | 4 +- .../broadcast/HttpBroadcastFactory.scala | 2 +- .../spark/broadcast/TorrentBroadcast.scala | 2 +- .../broadcast/TorrentBroadcastFactory.scala | 2 +- .../org/apache/spark/deploy/Client.scala | 4 +- .../apache/spark/deploy/ClientArguments.scala | 2 +- .../apache/spark/deploy/DeployMessage.scala | 2 +- .../spark/deploy/FaultToleranceTest.scala | 2 +- .../apache/spark/deploy/JsonProtocol.scala | 15 ++-- .../apache/spark/deploy/SparkHadoopUtil.scala | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../spark/deploy/SparkSubmitArguments.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../deploy/history/FsHistoryProvider.scala | 4 +- .../spark/deploy/history/HistoryServer.scala | 10 +-- .../master/FileSystemPersistenceEngine.scala | 2 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../deploy/master/RecoveryModeFactory.scala | 15 +++- .../spark/deploy/master/WorkerInfo.scala | 2 +- .../master/ZooKeeperPersistenceEngine.scala | 2 +- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../spark/deploy/worker/DriverRunner.scala | 22 +++-- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 2 +- .../apache/spark/executor/ExecutorActor.scala | 2 +- .../apache/spark/executor/TaskMetrics.scala | 90 +++++++++---------- .../spark/input/PortableDataStream.scala | 17 ++-- .../spark/mapred/SparkHadoopMapRedUtil.scala | 2 +- .../mapreduce/SparkHadoopMapReduceUtil.scala | 2 +- .../apache/spark/metrics/MetricsSystem.scala | 3 +- .../spark/metrics/sink/MetricsServlet.scala | 18 ++-- .../org/apache/spark/metrics/sink/Sink.scala | 4 +- .../spark/network/nio/BlockMessageArray.scala | 6 +- .../spark/network/nio/BufferMessage.scala | 10 +-- .../apache/spark/network/nio/Connection.scala | 8 +- .../spark/network/nio/ConnectionId.scala | 4 +- .../spark/network/nio/ConnectionManager.scala | 2 +- .../network/nio/ConnectionManagerId.scala | 2 +- .../apache/spark/network/nio/Message.scala | 6 +- .../spark/network/nio/MessageChunk.scala | 6 +- .../network/nio/MessageChunkHeader.scala | 4 +- .../apache/spark/partial/PartialResult.scala | 2 +- .../org/apache/spark/rdd/CartesianRDD.scala | 2 +- .../org/apache/spark/rdd/CoalescedRDD.scala | 10 +-- .../org/apache/spark/rdd/HadoopRDD.scala | 13 ++- .../scala/org/apache/spark/rdd/JdbcRDD.scala | 10 ++- .../apache/spark/rdd/MapPartitionsRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 2 +- .../spark/rdd/ParallelCollectionRDD.scala | 2 +- .../spark/rdd/PartitionPruningRDD.scala | 10 ++- .../scala/org/apache/spark/rdd/PipedRDD.scala | 8 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 12 +-- .../org/apache/spark/rdd/ShuffledRDD.scala | 4 +- .../org/apache/spark/rdd/SubtractedRDD.scala | 2 +- .../scala/org/apache/spark/rdd/UnionRDD.scala | 2 +- .../spark/rdd/ZippedPartitionsRDD.scala | 2 +- .../spark/scheduler/AccumulableInfo.scala | 7 +- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../scheduler/EventLoggingListener.scala | 61 ++++++++----- .../apache/spark/scheduler/JobLogger.scala | 2 +- .../apache/spark/scheduler/JobWaiter.scala | 2 +- .../scheduler/OutputCommitCoordinator.scala | 2 +- .../apache/spark/scheduler/ResultTask.scala | 2 +- .../spark/scheduler/ShuffleMapTask.scala | 4 +- .../spark/scheduler/SparkListener.scala | 4 +- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../apache/spark/scheduler/TaskLocation.scala | 19 ++-- .../spark/scheduler/TaskSchedulerImpl.scala | 16 ++-- .../spark/scheduler/TaskSetManager.scala | 13 +-- .../CoarseGrainedSchedulerBackend.scala | 2 +- .../cluster/YarnSchedulerBackend.scala | 2 +- .../scheduler/cluster/mesos/MemoryUtils.scala | 2 +- .../cluster/mesos/MesosSchedulerBackend.scala | 2 +- .../spark/scheduler/local/LocalBackend.scala | 4 +- .../spark/serializer/JavaSerializer.scala | 5 +- .../spark/serializer/KryoSerializer.scala | 2 +- .../shuffle/FileShuffleBlockManager.scala | 4 +- .../shuffle/IndexShuffleBlockManager.scala | 4 +- .../org/apache/spark/storage/BlockId.scala | 34 +++---- .../apache/spark/storage/BlockManagerId.scala | 8 +- .../spark/storage/BlockManagerMaster.scala | 2 +- .../storage/BlockManagerMasterActor.scala | 4 +- .../storage/BlockManagerSlaveActor.scala | 2 +- .../spark/storage/BlockObjectWriter.scala | 17 ++-- .../apache/spark/storage/FileSegment.scala | 4 +- .../org/apache/spark/storage/RDDInfo.scala | 4 +- .../apache/spark/storage/StorageLevel.scala | 16 ++-- .../spark/storage/StorageStatusListener.scala | 7 +- .../spark/storage/TachyonFileSegment.scala | 4 +- .../scala/org/apache/spark/ui/SparkUI.scala | 2 +- .../scala/org/apache/spark/ui/UIUtils.scala | 6 +- .../apache/spark/ui/UIWorkloadGenerator.scala | 4 +- .../apache/spark/ui/exec/ExecutorsTab.scala | 10 +-- .../spark/ui/jobs/JobProgressListener.scala | 14 +-- .../org/apache/spark/ui/jobs/JobsTab.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 10 ++- .../org/apache/spark/ui/jobs/StagesTab.scala | 6 +- .../org/apache/spark/ui/jobs/UIData.scala | 10 +-- .../apache/spark/ui/storage/StorageTab.scala | 12 +-- .../spark/util/CompletionIterator.scala | 10 +-- .../org/apache/spark/util/Distribution.scala | 2 +- .../org/apache/spark/util/ManualClock.scala | 30 +++---- .../apache/spark/util/MetadataCleaner.scala | 7 +- .../org/apache/spark/util/MutablePair.scala | 2 +- .../apache/spark/util/ParentClassLoader.scala | 2 +- .../spark/util/SerializableBuffer.scala | 2 +- .../org/apache/spark/util/StatCounter.scala | 4 +- .../util/TimeStampedWeakValueHashMap.scala | 6 +- .../scala/org/apache/spark/util/Utils.scala | 30 +++---- .../apache/spark/util/collection/BitSet.scala | 4 +- .../collection/ExternalAppendOnlyMap.scala | 4 +- .../util/collection/ExternalSorter.scala | 2 +- .../spark/util/collection/OpenHashMap.scala | 6 +- .../spark/util/collection/OpenHashSet.scala | 4 +- .../collection/PrimitiveKeyOpenHashMap.scala | 8 +- .../apache/spark/util/collection/Utils.scala | 2 +- .../spark/util/logging/FileAppender.scala | 4 +- .../spark/util/random/RandomSampler.scala | 44 +++++---- .../util/random/StratifiedSamplingUtils.scala | 2 +- .../spark/util/random/XORShiftRandom.scala | 2 +- project/MimaExcludes.scala | 8 +- scalastyle-config.xml | 4 +- 142 files changed, 597 insertions(+), 526 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index bcf832467f00b..330df1d59a9b1 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -18,8 +18,6 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} -import java.util.concurrent.atomic.AtomicLong -import java.lang.ThreadLocal import scala.collection.generic.Growable import scala.collection.mutable.Map @@ -109,7 +107,7 @@ class Accumulable[R, T] ( * The typical use of this method is to directly mutate the local value, eg., to add * an element to a Set. */ - def localValue = value_ + def localValue: R = value_ /** * Set the accumulator's value; only allowed on master. @@ -137,7 +135,7 @@ class Accumulable[R, T] ( Accumulators.register(this, false) } - override def toString = if (value_ == null) "null" else value_.toString + override def toString: String = if (value_ == null) "null" else value_.toString } /** @@ -257,22 +255,22 @@ object AccumulatorParam { implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } implicit object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // TODO: Add AccumulatorParams for other types, e.g. lists and strings @@ -351,6 +349,7 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) - def stringifyValue(value: Any) = "%s".format(value) + def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) + + def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 9a7cd4523e5ab..fc8cdde9348ee 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -74,7 +74,7 @@ class ShuffleDependency[K, V, C]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { - override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] + override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] val shuffleId: Int = _rdd.context.newShuffleId() @@ -91,7 +91,7 @@ class ShuffleDependency[K, V, C]( */ @DeveloperApi class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = List(partitionId) + override def getParents(partitionId: Int): List[Int] = List(partitionId) } @@ -107,7 +107,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) } else { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e97a7375a267b..91f9ef8ce7185 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -168,7 +168,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } } - def jobIds = Seq(jobWaiter.jobId) + def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } @@ -276,7 +276,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds = jobs + def jobIds: Seq[Int] = jobs } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 69178da1a7773..715f292f03469 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -65,7 +65,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule super.preStart() } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val unknownExecutor = !scheduler.executorHeartbeatReceived( executorId, taskMetrics, blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6e4edc7c80d7a..c9426c5de23a2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -43,7 +43,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster extends Actor with ActorLogReceive with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e53a78ead2c0e..b8d244408bc5b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -76,7 +76,7 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { case null => 0 @@ -154,7 +154,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } - def numPartitions = rangeBounds.length + 1 + def numPartitions: Int = rangeBounds.length + 1 private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index 55cb25946c2ad..cb2cae185256a 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -28,8 +28,10 @@ import org.apache.spark.util.Utils @DeveloperApi class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { - def value = t - override def toString = t.toString + + def value: T = t + + override def toString: String = t.toString private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { out.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 2ca19f53d2f07..0c123c96b8d7b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -133,7 +133,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Set multiple parameters together */ - def setAll(settings: Traversable[(String, String)]) = { + def setAll(settings: Traversable[(String, String)]): SparkConf = { this.settings.putAll(settings.toMap.asJava) this } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 228ff715fe7cb..a70be16f77eeb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -986,7 +986,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli union(Seq(first) ++ rest) /** Get an RDD that has no partitions or elements. */ - def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) + def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) // Methods for creating shared variables @@ -994,7 +994,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = { val acc = new Accumulator(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) @@ -1006,7 +1006,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the * driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) + : Accumulator[T] = { val acc = new Accumulator(initialValue, param, Some(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1018,7 +1019,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = { + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1031,7 +1033,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = { + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param, Some(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1209,7 +1212,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) /** The version of Spark on which this application is running. */ - def version = SPARK_VERSION + def version: String = SPARK_VERSION /** * Return a map from the slave to the max memory available for caching and the remaining @@ -1659,7 +1662,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - def getCheckpointDir = checkpointDir + def getCheckpointDir: Option[String] = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: Int = { @@ -1900,28 +1903,28 @@ object SparkContext extends Logging { "backward compatibility.", "1.3.0") object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // The following deprecated functions have already been moved to `object RDD` to @@ -1931,18 +1934,18 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = RDD.rddToPairRDDFunctions(rdd) - } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = + RDD.rddToAsyncRDDActions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]) = { + rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { val kf = implicitly[K => Writable] val vf = implicitly[V => Writable] // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it @@ -1954,16 +1957,17 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V, (K, V)] = RDD.rddToOrderedRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = + RDD.doubleRDDToDoubleRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]): DoubleRDDFunctions = RDD.numericRDDToDoubleRDDFunctions(rdd) // The following deprecated functions have already been moved to `object WritableFactory` to @@ -2134,7 +2138,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_N_REGEX(threads) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. val threadCount = if (threads == "*") localCpuCount else threads.toInt if (threadCount <= 0) { @@ -2146,7 +2150,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*, M] means the number of cores on the computer with M failures // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index c415fe99b105e..fe19f07e32d1b 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -27,9 +27,9 @@ private[spark] object TaskState extends Enumeration { type TaskState = Value - def isFailed(state: TaskState) = (LOST == state) || (FAILED == state) + def isFailed(state: TaskState): Boolean = (LOST == state) || (FAILED == state) - def isFinished(state: TaskState) = FINISHED_STATES.contains(state) + def isFinished(state: TaskState): Boolean = FINISHED_STATES.contains(state) def toMesos(state: TaskState): MesosTaskState = state match { case LAUNCHING => MesosTaskState.TASK_STARTING diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 35b324ba6f573..398ca41e16151 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -107,7 +107,7 @@ private[spark] object TestUtils { private class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { - override def getCharContent(ignoreEncodingErrors: Boolean) = code + override def getCharContent(ignoreEncodingErrors: Boolean): String = code } /** Creates a compiled class with the given name. Class file will be placed in destDir. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 3e9beb670f7ad..18ccd625fc8d1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -179,7 +179,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - override def toString = rdd.toString + override def toString: String = rdd.toString /** Assign a name to this RDD */ def setName(name: String): JavaRDD[T] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 6d6ed693be752..3be6783bba49d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -108,7 +108,7 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env - def statusTracker = new JavaSparkStatusTracker(sc) + def statusTracker: JavaSparkStatusTracker = new JavaSparkStatusTracker(sc) def isLocal: java.lang.Boolean = sc.isLocal diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 71b26737b8c02..8f9647eea9e25 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.api.java +import java.util.Map.Entry + import com.google.common.base.Optional import java.{util => ju} @@ -30,8 +32,8 @@ private[spark] object JavaUtils { } // Workaround for SPARK-3926 / SI-8911 - def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = - new SerializableMapWrapper(underlying) + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]): SerializableMapWrapper[A, B] + = new SerializableMapWrapper(underlying) // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper, // but implements java.io.Serializable. It can't just be subclassed to make it @@ -40,36 +42,33 @@ private[spark] object JavaUtils { class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) extends ju.AbstractMap[A, B] with java.io.Serializable { self => - override def size = underlying.size + override def size: Int = underlying.size override def get(key: AnyRef): B = try { - underlying get key.asInstanceOf[A] match { - case None => null.asInstanceOf[B] - case Some(v) => v - } + underlying.getOrElse(key.asInstanceOf[A], null.asInstanceOf[B]) } catch { case ex: ClassCastException => null.asInstanceOf[B] } override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] { - def size = self.size + override def size: Int = self.size - def iterator = new ju.Iterator[ju.Map.Entry[A, B]] { + override def iterator: ju.Iterator[ju.Map.Entry[A, B]] = new ju.Iterator[ju.Map.Entry[A, B]] { val ui = underlying.iterator var prev : Option[A] = None - def hasNext = ui.hasNext + def hasNext: Boolean = ui.hasNext - def next() = { - val (k, v) = ui.next + def next(): Entry[A, B] = { + val (k, v) = ui.next() prev = Some(k) new ju.Map.Entry[A, B] { import scala.util.hashing.byteswap32 - def getKey = k - def getValue = v - def setValue(v1 : B) = self.put(k, v1) - override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) - override def equals(other: Any) = other match { + override def getKey: A = k + override def getValue: B = v + override def setValue(v1 : B): B = self.put(k, v1) + override def hashCode: Int = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) + override def equals(other: Any): Boolean = other match { case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue case _ => false } @@ -81,7 +80,7 @@ private[spark] object JavaUtils { case Some(k) => underlying match { case mm: mutable.Map[A, _] => - mm remove k + mm.remove(k) prev = None case _ => throw new UnsupportedOperationException("remove") diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4c71b69069eb3..19f4c95fcad74 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -54,9 +54,11 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = firstParent.partitions + override def getPartitions: Array[Partition] = firstParent.partitions - override val partitioner = if (preservePartitoning) firstParent.partitioner else None + override val partitioner: Option[Partitioner] = { + if (preservePartitoning) firstParent.partitioner else None + } override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -92,7 +94,7 @@ private[spark] class PythonRDD( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { - def next(): Array[Byte] = { + override def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { _nextObj = read() @@ -175,7 +177,7 @@ private[spark] class PythonRDD( var _nextObj = read() - def hasNext = _nextObj != null + override def hasNext: Boolean = _nextObj != null } new InterruptibleIterator(context, stdoutIterator) } @@ -303,11 +305,10 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends - RDD[(Long, Array[Byte])](prev) { - override def getPartitions = prev.partitions - override val partitioner = prev.partitioner - override def compute(split: Partition, context: TaskContext) = +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { + override def getPartitions: Array[Partition] = prev.partitions + override val partitioner: Option[Partitioner] = prev.partitioner + override def compute(split: Partition, context: TaskContext): Iterator[(Long, Array[Byte])] = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (Utils.deserializeLongValue(a), b) case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) @@ -435,7 +436,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, minSplits: Int, - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] @@ -462,7 +463,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -488,7 +489,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -505,7 +506,7 @@ private[spark] object PythonRDD extends Logging { inputFormatClass: String, keyClass: String, valueClass: String, - conf: Configuration) = { + conf: Configuration): RDD[(K, V)] = { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] @@ -531,7 +532,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -557,7 +558,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -686,7 +687,7 @@ private[spark] object PythonRDD extends Logging { pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { saveAsHadoopFile( pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", null, null, null, null, new java.util.HashMap(), compressionCodecClass) @@ -711,7 +712,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -741,7 +742,7 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String]): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -766,7 +767,7 @@ private[spark] object PythonRDD extends Logging { confAsMap: java.util.HashMap[String, String], keyConverterClass: String, valueConverterClass: String, - useNewAPI: Boolean) = { + useNewAPI: Boolean): Unit = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), keyConverterClass, valueConverterClass, new JavaToWritableConverter) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index fb52a960e0765..257491e90dd66 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -84,7 +84,7 @@ private[spark] object SerDeUtil extends Logging { private var initialized = false // This should be called before trying to unpickle array.array from Python // In cluster mode, this should be put in closure - def initialize() = { + def initialize(): Unit = { synchronized{ if (!initialized) { Unpickler.registerConstructor("array", "array", new ArrayConstructor()) diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index cf289fb3ae39f..8f30ff9202c83 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,38 +18,37 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} +import java.{util => ju} import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat + +import org.apache.spark.SparkException import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.{SparkContext, SparkException} /** * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python - * @param str - * @param int - * @param double */ case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable { def this() = this("", 0, 0.0) - def getStr = str + def getStr: String = str def setStr(str: String) { this.str = str } - def getInt = int + def getInt: Int = int def setInt(int: Int) { this.int = int } - def getDouble = double + def getDouble: Double = double def setDouble(double: Double) { this.double = double } - def write(out: DataOutput) = { + def write(out: DataOutput): Unit = { out.writeUTF(str) out.writeInt(int) out.writeDouble(double) } - def readFields(in: DataInput) = { + def readFields(in: DataInput): Unit = { str = in.readUTF() int = in.readInt() double = in.readDouble() @@ -57,28 +56,28 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } private[python] class TestInputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Char = { obj.asInstanceOf[IntWritable].get().toChar } } private[python] class TestInputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) } } private[python] class TestOutputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Text = { new Text(obj.asInstanceOf[Int].toString) } } private[python] class TestOutputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): DoubleWritable = { new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) } } @@ -86,7 +85,7 @@ private[python] class TestOutputValueConverter extends Converter[Any, Any] { private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { - override def convert(obj: Any) = obj match { + override def convert(obj: Any): DoubleArrayWritable = obj match { case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => val daw = new DoubleArrayWritable daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index a5ea478f231d7..12d79f6ed311b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -146,5 +146,5 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo } } - override def toString = "Broadcast(" + id + ")" + override def toString: String = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 8f8a0b11f9f2e..685313ac009ba 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -58,7 +58,7 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = { + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 1444c0dd3d2d6..74ccfa6d3c9a3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -160,7 +160,7 @@ private[broadcast] object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) private def write(id: Long, value: Any) { val file = getFile(id) @@ -222,7 +222,7 @@ private[broadcast] object HttpBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver * and delete the associated broadcast file. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) if (removeFromDriver) { val file = getFile(id) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c7ef02d572a19..cf3ae36f27949 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = new HttpBroadcast[T](value_, isLocal, id) override def stop() { HttpBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 94142d33369c7..23b02e60338fb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -222,7 +222,7 @@ private object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. * If removeFromDriver is true, also remove these persisted blocks on the driver. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index fb024c12094f2..96d8dd79908c8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -30,7 +30,7 @@ class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = { + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { new TorrentBroadcast[T](value_, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 237d26fc6bd0e..65238af2caa24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -38,7 +38,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) var masterActor: ActorSelection = _ val timeout = AkkaUtils.askTimeout(conf) - override def preStart() = { + override def preStart(): Unit = { masterActor = context.actorSelection( Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) @@ -118,7 +118,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case SubmitDriverResponse(success, driverId, message) => println(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 53bc62aff7395..5cbac787dceeb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -42,7 +42,7 @@ private[deploy] class ClientArguments(args: Array[String]) { var memory: Int = DEFAULT_MEMORY var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() - def driverOptions = _driverOptions.toSeq + def driverOptions: Seq[String] = _driverOptions.toSeq // kill parameters var driverId: String = "" diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 7f600d89604a2..0997507d016f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -162,7 +162,7 @@ private[deploy] object DeployMessages { Utils.checkHost(host, "Required hostname") assert (port > 0) - def uri = "spark://" + host + ":" + port + def uri: String = "spark://" + host + ":" + port def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 5668b53fc6f4f..a7c89276a045e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -426,7 +426,7 @@ private object SparkDocker { } private class DockerId(val id: String) { - override def toString = id + override def toString: String = id } private object Docker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 458a7c3a455de..dfc5b97e6a6c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} @@ -24,7 +25,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.ExecutorRunner private[deploy] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo) = { + def writeWorkerInfo(obj: WorkerInfo): JObject = { ("id" -> obj.id) ~ ("host" -> obj.host) ~ ("port" -> obj.port) ~ @@ -39,7 +40,7 @@ private[deploy] object JsonProtocol { ("lastheartbeat" -> obj.lastHeartbeat) } - def writeApplicationInfo(obj: ApplicationInfo) = { + def writeApplicationInfo(obj: ApplicationInfo): JObject = { ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ @@ -51,7 +52,7 @@ private[deploy] object JsonProtocol { ("duration" -> obj.duration) } - def writeApplicationDescription(obj: ApplicationDescription) = { + def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ ("memoryperslave" -> obj.memoryPerSlave) ~ @@ -59,14 +60,14 @@ private[deploy] object JsonProtocol { ("command" -> obj.command.toString) } - def writeExecutorRunner(obj: ExecutorRunner) = { + def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ ("appid" -> obj.appId) ~ ("appdesc" -> writeApplicationDescription(obj.appDesc)) } - def writeDriverInfo(obj: DriverInfo) = { + def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ @@ -74,7 +75,7 @@ private[deploy] object JsonProtocol { ("memory" -> obj.desc.mem) } - def writeMasterState(obj: MasterStateResponse) = { + def writeMasterState(obj: MasterStateResponse): JObject = { ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ ("cores" -> obj.workers.map(_.cores).sum) ~ @@ -87,7 +88,7 @@ private[deploy] object JsonProtocol { ("status" -> obj.status.toString) } - def writeWorkerState(obj: WorkerStateResponse) = { + def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ ("masterwebuiurl" -> obj.masterWebUiUrl) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e0a32fb65cd51..c2568eb4b60ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -193,7 +193,7 @@ class SparkHadoopUtil extends Logging { * that file. */ def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { - def recurse(path: Path) = { + def recurse(path: Path): Array[FileStatus] = { val (directories, leaves) = fs.listStatus(path).partition(_.isDir) leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4f506be63fe59..660307d19eab4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -777,7 +777,7 @@ private[deploy] object SparkSubmitUtils { } /** A nice function to use in tests as well. Values are dummy strings. */ - def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 2250d5a28e4ef..6eb73c43470a5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -252,7 +252,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S master.startsWith("spark://") && deployMode == "cluster" } - override def toString = { + override def toString: String = { s"""Parsed arguments: | master $master | deployMode $deployMode diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 2d24083a77b73..3b729725257ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -116,7 +116,7 @@ private[spark] class AppClient( masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index db7c499661319..80c9c13ddec1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -93,7 +93,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def getRunner(operateFun: () => Unit): Runnable = { new Runnable() { - override def run() = Utils.tryOrExit { + override def run(): Unit = Utils.tryOrExit { operateFun() } } @@ -141,7 +141,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } - override def getListing() = applications.values + override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values override def getAppUI(appId: String): Option[SparkUI] = { try { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index af483d560b33e..72f6048239297 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -61,7 +61,7 @@ class HistoryServer( private val appCache = CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(new RemovalListener[String, SparkUI] { - override def onRemoval(rm: RemovalNotification[String, SparkUI]) = { + override def onRemoval(rm: RemovalNotification[String, SparkUI]): Unit = { detachSparkUI(rm.getValue()) } }) @@ -149,14 +149,14 @@ class HistoryServer( * * @return List of all known applications. */ - def getApplicationList() = provider.getListing() + def getApplicationList(): Iterable[ApplicationHistoryInfo] = provider.getListing() /** * Returns the provider configuration to show in the listing page. * * @return A map with the provider's configuration. */ - def getProviderConfig() = provider.getConfig() + def getProviderConfig(): Map[String, String] = provider.getConfig() } @@ -195,9 +195,7 @@ object HistoryServer extends Logging { server.bind() Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { - override def run() = { - server.stop() - } + override def run(): Unit = server.stop() }) // Wait until the end of the world... or if the HistoryServer process is manually stopped diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index d2d30bfd7fcba..32499b3a784a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -48,7 +48,7 @@ private[master] class FileSystemPersistenceEngine( new File(dir + File.separator + name).delete() } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) files.map(deserializeFromFile[T]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1b42121c8db05..80506621f4d24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -204,7 +204,7 @@ private[master] class Master( self ! RevokedLeadership } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 1583bf1f60032..351db8fab2041 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -51,20 +51,27 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial */ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") - def createPersistenceEngine() = { + def createPersistenceEngine(): PersistenceEngine = { logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) } - def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } } private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) { - def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) - def createLeaderElectionAgent(master: LeaderElectable) = + def createPersistenceEngine(): PersistenceEngine = { + new ZooKeeperPersistenceEngine(conf, serializer) + } + + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { new ZooKeeperLeaderElectionAgent(master, conf) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e94aae93e4495..9b3d48c6edc84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -104,7 +104,7 @@ private[spark] class WorkerInfo( "http://" + this.publicAddress + ":" + this.webUiPort } - def setState(state: WorkerState.Value) = { + def setState(state: WorkerState.Value): Unit = { this.state = state } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 1ac6677ad2b6d..a285783f72000 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -46,7 +46,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat zk.delete().forPath(WORKING_DIR + "/" + name) } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) file.map(deserializeFromFile[T]).flatten } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index dee2e4a447c6e..46509e39c0f23 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -95,7 +95,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. - def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0 + def hasDrivers: Boolean = activeDrivers.length > 0 || completedDrivers.length > 0 val content =
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 27a9eabb1ede7..e0948e16ef354 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -56,8 +56,14 @@ private[deploy] class DriverRunner( private var finalExitCode: Option[Int] = None // Decoupled for testing - def setClock(_clock: Clock) = clock = _clock - def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper + def setClock(_clock: Clock): Unit = { + clock = _clock + } + + def setSleeper(_sleeper: Sleeper): Unit = { + sleeper = _sleeper + } + private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) @@ -155,7 +161,7 @@ private[deploy] class DriverRunner( private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) { builder.directory(baseDir) - def initialize(process: Process) = { + def initialize(process: Process): Unit = { // Redirect stdout and stderr to files val stdout = new File(baseDir, "stdout") CommandUtils.redirectStream(process.getInputStream, stdout) @@ -169,8 +175,8 @@ private[deploy] class DriverRunner( runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) } - def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit, - supervise: Boolean) { + def runCommandWithRetry( + command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = { // Time to wait between submission retries. var waitSeconds = 1 // A run of this many seconds resets the exponential back-off. @@ -216,8 +222,8 @@ private[deploy] trait ProcessBuilderLike { } private[deploy] object ProcessBuilderLike { - def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike { - def start() = processBuilder.start() - def command = processBuilder.command() + def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { + override def start(): Process = processBuilder.start() + override def command: Seq[String] = processBuilder.command() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c1b0a295f9f74..c4c24a7866aa3 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -275,7 +275,7 @@ private[worker] class Worker( } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 09d866fb0cd90..e0790274d7d3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -50,7 +50,7 @@ private[spark] class WorkerWatcher(workerUrl: String) private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index dd19e4947db1e..b5205d4e997ae 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -62,7 +62,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index 41925f7e97e84..3e47d13f7545d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -33,7 +33,7 @@ private[spark] case object TriggerThreadDump private[spark] class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case TriggerThreadDump => sender ! Utils.getThreadDump() } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 07b152651dedf..06152f16ae618 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,13 +17,10 @@ package org.apache.spark.executor -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark.executor.DataReadMethod.DataReadMethod - import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} /** @@ -44,14 +41,14 @@ class TaskMetrics extends Serializable { * Host's name the task runs on */ private var _hostname: String = _ - def hostname = _hostname + def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ - def executorDeserializeTime = _executorDeserializeTime + def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value @@ -59,14 +56,14 @@ class TaskMetrics extends Serializable { * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ - def executorRunTime = _executorRunTime + def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value /** * The number of bytes this task transmitted back to the driver as the TaskResult */ private var _resultSize: Long = _ - def resultSize = _resultSize + def resultSize: Long = _resultSize private[spark] def setResultSize(value: Long) = _resultSize = value @@ -74,31 +71,31 @@ class TaskMetrics extends Serializable { * Amount of time the JVM spent in garbage collection while executing this task */ private var _jvmGCTime: Long = _ - def jvmGCTime = _jvmGCTime + def jvmGCTime: Long = _jvmGCTime private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value /** * Amount of time spent serializing the task result */ private var _resultSerializationTime: Long = _ - def resultSerializationTime = _resultSerializationTime + def resultSerializationTime: Long = _resultSerializationTime private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value /** * The number of in-memory bytes spilled by this task */ private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value + def memoryBytesSpilled: Long = _memoryBytesSpilled + private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value + private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value /** * The number of on-disk bytes spilled by this task */ private var _diskBytesSpilled: Long = _ - def diskBytesSpilled = _diskBytesSpilled - def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value + def diskBytesSpilled: Long = _diskBytesSpilled + def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read @@ -106,7 +103,7 @@ class TaskMetrics extends Serializable { */ private var _inputMetrics: Option[InputMetrics] = None - def inputMetrics = _inputMetrics + def inputMetrics: Option[InputMetrics] = _inputMetrics /** * This should only be used when recreating TaskMetrics, not when updating input metrics in @@ -128,7 +125,7 @@ class TaskMetrics extends Serializable { */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None - def shuffleReadMetrics = _shuffleReadMetrics + def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics /** * This should only be used when recreating TaskMetrics, not when updating read metrics in @@ -177,17 +174,18 @@ class TaskMetrics extends Serializable { * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, * we can store all the different inputMetrics (one per readMethod). */ - private[spark] def getInputMetricsForReadMethod( - readMethod: DataReadMethod): InputMetrics = synchronized { - _inputMetrics match { - case None => - val metrics = new InputMetrics(readMethod) - _inputMetrics = Some(metrics) - metrics - case Some(metrics @ InputMetrics(method)) if method == readMethod => - metrics - case Some(InputMetrics(method)) => - new InputMetrics(readMethod) + private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = { + synchronized { + _inputMetrics match { + case None => + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + case Some(metrics @ InputMetrics(method)) if method == readMethod => + metrics + case Some(InputMetrics(method)) => + new InputMetrics(readMethod) + } } } @@ -256,14 +254,14 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _bytesRead: Long = _ def bytesRead: Long = _bytesRead - def incBytesRead(bytes: Long) = _bytesRead += bytes + def incBytesRead(bytes: Long): Unit = _bytesRead += bytes /** * Total records read. */ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long) = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. @@ -293,15 +291,15 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { * Total bytes written */ private var _bytesWritten: Long = _ - def bytesWritten = _bytesWritten - private[spark] def setBytesWritten(value : Long) = _bytesWritten = value + def bytesWritten: Long = _bytesWritten + private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value /** * Total records written */ private var _recordsWritten: Long = 0L - def recordsWritten = _recordsWritten - private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value + def recordsWritten: Long = _recordsWritten + private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value } /** @@ -314,7 +312,7 @@ class ShuffleReadMetrics extends Serializable { * Number of remote blocks fetched in this shuffle by this task */ private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched = _remoteBlocksFetched + def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value @@ -322,7 +320,7 @@ class ShuffleReadMetrics extends Serializable { * Number of local blocks fetched in this shuffle by this task */ private var _localBlocksFetched: Int = _ - def localBlocksFetched = _localBlocksFetched + def localBlocksFetched: Int = _localBlocksFetched private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value @@ -332,7 +330,7 @@ class ShuffleReadMetrics extends Serializable { * still not finished processing block A, it is not considered to be blocking on block B. */ private var _fetchWaitTime: Long = _ - def fetchWaitTime = _fetchWaitTime + def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value @@ -340,7 +338,7 @@ class ShuffleReadMetrics extends Serializable { * Total number of remote bytes read from the shuffle by this task */ private var _remoteBytesRead: Long = _ - def remoteBytesRead = _remoteBytesRead + def remoteBytesRead: Long = _remoteBytesRead private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value @@ -348,24 +346,24 @@ class ShuffleReadMetrics extends Serializable { * Shuffle data that was read from the local disk (as opposed to from a remote executor). */ private var _localBytesRead: Long = _ - def localBytesRead = _localBytesRead + def localBytesRead: Long = _localBytesRead private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value /** * Total bytes fetched in the shuffle by this task (both remote and local). */ - def totalBytesRead = _remoteBytesRead + _localBytesRead + def totalBytesRead: Long = _remoteBytesRead + _localBytesRead /** * Number of blocks fetched in this shuffle by this task (remote or local) */ - def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched + def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched /** * Total number of records read from the shuffle by this task */ private var _recordsRead: Long = _ - def recordsRead = _recordsRead + def recordsRead: Long = _recordsRead private[spark] def incRecordsRead(value: Long) = _recordsRead += value private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } @@ -380,7 +378,7 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for the shuffle by this task */ @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten = _shuffleBytesWritten + def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value @@ -388,7 +386,7 @@ class ShuffleWriteMetrics extends Serializable { * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime= _shuffleWriteTime + def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value @@ -396,7 +394,7 @@ class ShuffleWriteMetrics extends Serializable { * Total number of records written to the shuffle by this task */ @volatile private var _shuffleRecordsWritten: Long = _ - def shuffleRecordsWritten = _shuffleRecordsWritten + def shuffleRecordsWritten: Long = _shuffleRecordsWritten private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 593a62b3e3b32..6cda7772f77bc 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -73,16 +73,16 @@ private[spark] abstract class StreamBasedRecordReader[T]( private var key = "" private var value: T = null.asInstanceOf[T] - override def initialize(split: InputSplit, context: TaskAttemptContext) = {} - override def close() = {} + override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} + override def close(): Unit = {} - override def getProgress = if (processed) 1.0f else 0.0f + override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey = key + override def getCurrentKey: String = key - override def getCurrentValue = value + override def getCurrentValue: T = value - override def nextKeyValue = { + override def nextKeyValue: Boolean = { if (!processed) { val fileIn = new PortableDataStream(split, context, index) value = parseStream(fileIn) @@ -119,7 +119,8 @@ private[spark] class StreamRecordReader( * The format for the PortableDataStream files */ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { - override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) + : CombineFileRecordReader[String, PortableDataStream] = { new CombineFileRecordReader[String, PortableDataStream]( split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) } @@ -204,7 +205,7 @@ class PortableDataStream( /** * Close the file (if it is currently open) */ - def close() = { + def close(): Unit = { if (isOpen) { try { fileIn.close() diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 21b782edd2a9e..87c2aa481095d 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -52,7 +52,7 @@ trait SparkHadoopMapRedUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 3340673f91156..cfd20392d12f1 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -45,7 +45,7 @@ trait SparkHadoopMapReduceUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 345db36630fd5..9150ad35712a1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} @@ -84,7 +85,7 @@ private[spark] class MetricsSystem private ( /** * Get any UI handlers used by this metrics system; can only be called after start(). */ - def getServletHandlers = { + def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") metricsServlet.map(_.getHandlers).getOrElse(Array()) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 2f65bc8b46609..0c2e212a33074 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -30,8 +30,12 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils._ -private[spark] class MetricsServlet(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class MetricsServlet( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -45,10 +49,12 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[ServletContextHandler]( - createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) - ) + def getHandlers: Array[ServletContextHandler] = { + Array[ServletContextHandler]( + createServletHandler(servletPath, + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + ) + } def getMetricsSnapshot(request: HttpServletRequest): String = { mapper.writeValueAsString(registry) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 0d83d8c425ca4..9fad4e7deacb6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { - def start: Unit - def stop: Unit + def start(): Unit + def stop(): Unit def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index a1a2c00ed1542..1ba25aa74aa02 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -32,11 +32,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - def apply(i: Int) = blockMessages(i) + def apply(i: Int): BlockMessage = blockMessages(i) - def iterator = blockMessages.iterator + def iterator: Iterator[BlockMessage] = blockMessages.iterator - def length = blockMessages.length + def length: Int = blockMessages.length def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala index 3b245c5c7a4f3..9a9e22b0c2366 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -31,9 +31,9 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val initialSize = currentSize() var gotChunkForSendingOnce = false - def size = initialSize + def size: Int = initialSize - def currentSize() = { + def currentSize(): Int = { if (buffers == null || buffers.isEmpty) { 0 } else { @@ -100,11 +100,11 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: buffers.foreach(_.flip) } - def hasAckId() = (ackId != 0) + def hasAckId(): Boolean = ackId != 0 - def isCompletelyReceived() = !buffers(0).hasRemaining + def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - override def toString = { + override def toString: String = { if (hasAckId) { "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index c2d9578be7ebb..04eb2bf9ba4ab 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -101,9 +101,11 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, socketRemoteConnectionManagerId } - def key() = channel.keyFor(selector) + def key(): SelectionKey = channel.keyFor(selector) - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + def getRemoteAddress(): InetSocketAddress = { + channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + } // Returns whether we have to register for further reads or not. def read(): Boolean = { @@ -280,7 +282,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, /* channel.socket.setSendBufferSize(256 * 1024) */ - override def getRemoteAddress() = address + override def getRemoteAddress(): InetSocketAddress = address val DEFAULT_INTEREST = SelectionKey.OP_READ diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala index 764dc5e5503ed..b3b281ff465f1 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala @@ -18,7 +18,9 @@ package org.apache.spark.network.nio private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + override def toString: String = { + connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + } } private[nio] object ConnectionId { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index ee22c6656e69e..741fe3e1ea750 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -188,7 +188,7 @@ private[nio] class ConnectionManager( private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private val selectorThread = new Thread("connection-manager-thread") { - override def run() = ConnectionManager.this.run() + override def run(): Unit = ConnectionManager.this.run() } selectorThread.setDaemon(true) // start this thread last, since it invokes run(), which accesses members above diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala index cbb37ec5ced1f..1cd13d887c6f6 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -26,7 +26,7 @@ private[nio] case class ConnectionManagerId(host: String, port: Int) { Utils.checkHost(host) assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) + def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index fb4a979b824c3..85d2fe2bf9c20 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -42,7 +42,9 @@ private[nio] abstract class Message(val typ: Long, val id: Int) { def timeTaken(): String = (finishTime - startTime).toString + " ms" - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + override def toString: String = { + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + } } @@ -51,7 +53,7 @@ private[nio] object Message { var lastId = 1 - def getNewId() = synchronized { + def getNewId(): Int = synchronized { lastId += 1 if (lastId == 0) { lastId += 1 diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala index 278c5ac356ef2..a4568e849fa13 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer private[nio] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining + val size: Int = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { + lazy val buffers: ArrayBuffer[ByteBuffer] = { val ab = new ArrayBuffer[ByteBuffer]() ab += header.buffer if (buffer != null) { @@ -35,7 +35,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { ab } - override def toString = { + override def toString: String = { "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala index 6e20f291c5cec..7b3da4bb9d5ee 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -50,8 +50,10 @@ private[nio] class MessageChunkHeader( flip.asInstanceOf[ByteBuffer] } - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + override def toString: String = { + "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index cadd0c7ed19ba..53c4b32c95ab3 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -99,7 +99,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { case None => "(partial: " + initialValue + ")" } } - def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) + def getFinalValueInternal(): Option[T] = PartialResult.this.getFinalValueInternal().map(f) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 1cbd684224b7c..9059eb13bb5d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -70,7 +70,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct } - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index b073eba8a1574..5117ccfabfcc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -186,7 +186,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: override val isEmpty = !it.hasNext // initializes/resets to start iterating from the beginning - def resetIterator() = { + def resetIterator(): Iterator[(String, Partition)] = { val iterators = (0 to 2).map( x => prev.partitions.iterator.flatMap(p => { if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None @@ -196,10 +196,10 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: } // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - def hasNext(): Boolean = { !isEmpty } + override def hasNext: Boolean = { !isEmpty } // return the next preferredLocation of some partition of the RDD - def next(): (String, Partition) = { + override def next(): (String, Partition) = { if (it.hasNext) { it.next() } else { @@ -237,7 +237,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: val rotIt = new LocationIterator(prev) // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext()) { + if (!rotIt.hasNext) { (1 to targetLen).foreach(x => groupArr += PartitionGroup()) return } @@ -343,7 +343,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: private case class PartitionGroup(prefLoc: Option[String] = None) { var arr = mutable.ArrayBuffer[Partition]() - def size = arr.size + def size: Int = arr.size } private object PartitionGroup { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 486e86ce1bb19..f77abac42b623 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,8 +215,7 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes @@ -240,7 +239,7 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - override def getNext() = { + override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { @@ -337,11 +336,11 @@ private[spark] object HadoopRDD extends Logging { * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. */ - def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key) + def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key) + def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - def putCachedMetadata(key: String, value: Any) = + private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) /** Add Hadoop configuration specific to a single partition and attempt. */ @@ -371,7 +370,7 @@ private[spark] object HadoopRDD extends Logging { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[HadoopPartition] val inputSplit = partition.inputSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index e2267861e79df..0c28f045e46e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.sql.{Connection, ResultSet} +import java.sql.{PreparedStatement, Connection, ResultSet} import scala.reflect.ClassTag @@ -28,8 +28,9 @@ import org.apache.spark.util.NextIterator import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx + override def index: Int = idx } + // TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private /** * An RDD that executes an SQL query on a JDBC connection and reads results. @@ -70,7 +71,8 @@ class JdbcRDD[T: ClassTag]( }).toArray } - override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] + { context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() @@ -88,7 +90,7 @@ class JdbcRDD[T: ClassTag]( stmt.setLong(2, part.upper) val rs = stmt.executeQuery() - override def getNext: T = { + override def getNext(): T = { if (rs.next()) { mapRow(rs) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4883fb828814c..a838aac6e8d1a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -31,6 +31,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7fb94840df99c..2ab967f4bb313 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -238,7 +238,7 @@ private[spark] object NewHadoopRDD { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[NewHadoopPartition] val inputSplit = partition.serializableHadoopSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index f12d0cffaba34..e2394e28f8d26 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -98,7 +98,7 @@ private[spark] class ParallelCollectionRDD[T: ClassTag]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = { + override def compute(s: Partition, context: TaskContext): Iterator[T] = { new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index f781a8d776f2a..a00f4c1cdff91 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -40,7 +40,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) } } @@ -59,8 +59,10 @@ class PartitionPruningRDD[T: ClassTag]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + firstParent[T].iterator( + split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + } override protected def getPartitions: Array[Partition] = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions @@ -74,7 +76,7 @@ object PartitionPruningRDD { * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD * when its type T is not known at compile time. */ - def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { + def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean): PartitionPruningRDD[T] = { new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index ed79032893d33..dc60d48927624 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -149,10 +149,10 @@ private[spark] class PipedRDD[T: ClassTag]( }.start() // Return an iterator that read lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines + val lines = Source.fromInputStream(proc.getInputStream).getLines() new Iterator[String] { - def next() = lines.next() - def hasNext = { + def next(): String = lines.next() + def hasNext: Boolean = { if (lines.hasNext) { true } else { @@ -162,7 +162,7 @@ private[spark] class PipedRDD[T: ClassTag]( } // cleanup task working directory if used - if (workInTaskDirectory == true) { + if (workInTaskDirectory) { scala.util.control.Exception.ignoring(classOf[IOException]) { Utils.deleteRecursively(new File(taskDirectory)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a4c74ed03e330..ddbfd5624e741 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -186,7 +186,7 @@ abstract class RDD[T: ClassTag]( } /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel = storageLevel + def getStorageLevel: StorageLevel = storageLevel // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed @@ -746,13 +746,13 @@ abstract class RDD[T: ClassTag]( def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { - def hasNext = (thisIter.hasNext, otherIter.hasNext) match { + def hasNext: Boolean = (thisIter.hasNext, otherIter.hasNext) match { case (true, true) => true case (false, false) => false case _ => throw new SparkException("Can only zip RDDs with " + "same number of elements in each partition") } - def next = (thisIter.next, otherIter.next) + def next(): (T, U) = (thisIter.next(), otherIter.next()) } } } @@ -868,8 +868,8 @@ abstract class RDD[T: ClassTag]( // Our partitioner knows how to handle T (which, since we have a partitioner, is // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) + override def numPartitions: Int = p.numPartitions + override def getPartition(k: Any): Int = p.getPartition(k.asInstanceOf[(Any, _)]._1) } // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies // anyway, and when calling .keys, will not have a partitioner set, even though @@ -1394,7 +1394,7 @@ abstract class RDD[T: ClassTag]( } /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ - def context = sc + def context: SparkContext = sc /** * Private API for changing an RDD's ClassTag. diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index d9fe6847254fa..2dc47f95937cb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,14 +17,12 @@ package org.apache.spark.rdd -import scala.reflect.ClassTag - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { - override val index = idx + override val index: Int = idx override def hashCode(): Int = idx } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index ed24ea22a661c..c27f435eb9c5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -105,7 +105,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { + def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index aece683ff3199..4239e7e22af89 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -44,7 +44,7 @@ private[spark] class UnionPartition[T: ClassTag]( var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(parentPartition) + def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition) override val index: Int = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 95b2dd954e9f4..d0be304762e1f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -32,7 +32,7 @@ private[spark] class ZippedPartitionsPartition( override val index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues + def partitions: Seq[Partition] = partitionValues @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index fa83372bb4d11..e0edd7d4ae968 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -39,8 +39,11 @@ class AccumulableInfo ( } object AccumulableInfo { - def apply(id: Long, name: String, update: Option[String], value: String) = + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { new AccumulableInfo(id, name, update, value) + } - def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) + def apply(id: Long, name: String, value: String): AccumulableInfo = { + new AccumulableInfo(id, name, None, value) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8feac6cb6b7a1..b405bd3338e7c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -946,7 +946,7 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { + def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 34fa6d27c3a45..c0d889360ae99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -149,47 +149,60 @@ private[spark] class EventLoggingListener( } // Events that do not trigger a flush - override def onStageSubmitted(event: SparkListenerStageSubmitted) = - logEvent(event) - override def onTaskStart(event: SparkListenerTaskStart) = - logEvent(event) - override def onTaskGettingResult(event: SparkListenerTaskGettingResult) = - logEvent(event) - override def onTaskEnd(event: SparkListenerTaskEnd) = - logEvent(event) - override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate) = - logEvent(event) + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) + + override def onTaskStart(event: SparkListenerTaskStart): Unit = logEvent(event) + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = logEvent(event) + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = logEvent(event) + + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = logEvent(event) // Events that trigger a flush - override def onStageCompleted(event: SparkListenerStageCompleted) = - logEvent(event, flushLogger = true) - override def onJobStart(event: SparkListenerJobStart) = - logEvent(event, flushLogger = true) - override def onJobEnd(event: SparkListenerJobEnd) = + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded) = + } + + override def onJobStart(event: SparkListenerJobStart): Unit = logEvent(event, flushLogger = true) + + override def onJobEnd(event: SparkListenerJobEnd): Unit = logEvent(event, flushLogger = true) + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved) = + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { logEvent(event, flushLogger = true) - override def onUnpersistRDD(event: SparkListenerUnpersistRDD) = + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { logEvent(event, flushLogger = true) - override def onApplicationStart(event: SparkListenerApplicationStart) = + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { logEvent(event, flushLogger = true) - override def onApplicationEnd(event: SparkListenerApplicationEnd) = + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { logEvent(event, flushLogger = true) - override def onExecutorAdded(event: SparkListenerExecutorAdded) = + } + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { logEvent(event, flushLogger = true) - override def onExecutorRemoved(event: SparkListenerExecutorRemoved) = + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { logEvent(event, flushLogger = true) + } // No-op because logging every update would be overkill - override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. */ - def stop() = { + def stop(): Unit = { writer.foreach(_.close()) val target = new Path(logPath) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 8aa528ac573d0..e55b76c36cc5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -57,7 +57,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener private val stageIdToJobId = new HashMap[Int, Int] private val jobIdToStageIds = new HashMap[Int, Seq[Int]] private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } createLogDir() diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 29879b374b801..382b09422a4a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -34,7 +34,7 @@ private[spark] class JobWaiter[T]( @volatile private var _jobFinished = totalTasks == 0 - def jobFinished = _jobFinished + def jobFinished: Boolean = _jobFinished // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 759df023a6dcf..a3caa9f000c89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -160,7 +160,7 @@ private[spark] object OutputCommitCoordinator { class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) extends Actor with ActorLogReceive with Logging { - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) case StopCoordinator => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 4a9ff918afe25..e074ce6ebff0b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -64,5 +64,5 @@ private[spark] class ResultTask[T, U]( // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" + override def toString: String = "ResultTask(" + stageId + ", " + partitionId + ")" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 79709089c0da4..fd0d484b45460 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -47,7 +47,7 @@ private[spark] class ShuffleMapTask( /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, null, new Partition { override def index = 0 }, null) + this(0, null, new Partition { override def index: Int = 0 }, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -83,5 +83,5 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) + override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 52720d48ca67f..b711ff209af94 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -300,7 +300,7 @@ private[spark] object StatsReportListener extends Logging { } def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { - def f(d: Double) = format.format(d) + def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -346,7 +346,7 @@ private[spark] object StatsReportListener extends Logging { /** * Reformat a time interval in milliseconds to a prettier format for output */ - def millisToString(ms: Long) = { + def millisToString(ms: Long): String = { val (size, units) = if (ms > hours) { (ms.toDouble / hours, "hours") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index cc13f57a49b89..4cbc6e84a6bdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -133,7 +133,7 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - override def toString = "Stage " + id + override def toString: String = "Stage " + id override def hashCode(): Int = id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 10c685f29d3ac..da07ce2c6ea49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -29,23 +29,22 @@ private[spark] sealed trait TaskLocation { /** * A location that includes both a host and an executor id on that host. */ -private [spark] case class ExecutorCacheTaskLocation(override val host: String, - val executorId: String) extends TaskLocation { -} +private [spark] +case class ExecutorCacheTaskLocation(override val host: String, executorId: String) + extends TaskLocation /** * A location on a host. */ private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation { - override def toString = host + override def toString: String = host } /** * A location on a host that is cached by HDFS. */ -private [spark] case class HDFSCacheTaskLocation(override val host: String) - extends TaskLocation { - override def toString = TaskLocation.inMemoryLocationTag + host +private [spark] case class HDFSCacheTaskLocation(override val host: String) extends TaskLocation { + override def toString: String = TaskLocation.inMemoryLocationTag + host } private[spark] object TaskLocation { @@ -54,14 +53,16 @@ private[spark] object TaskLocation { // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. val inMemoryLocationTag = "hdfs_cache_" - def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId) + def apply(host: String, executorId: String): TaskLocation = { + new ExecutorCacheTaskLocation(host, executorId) + } /** * Create a TaskLocation from a string returned by getPreferredLocations. * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the * location is cached. */ - def apply(str: String) = { + def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { new HostTaskLocation(str) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f33fd4450b2a6..076b36e86c0ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -373,17 +373,17 @@ private[spark] class TaskSchedulerImpl( } def handleSuccessfulTask( - taskSetManager: TaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]): Unit = synchronized { taskSetManager.handleSuccessfulTask(tid, taskResult) } def handleFailedTask( - taskSetManager: TaskSetManager, - tid: Long, - taskState: TaskState, - reason: TaskEndReason) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: TaskEndReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to @@ -423,7 +423,7 @@ private[spark] class TaskSchedulerImpl( starvationTimer.cancel() } - override def defaultParallelism() = backend.defaultParallelism() + override def defaultParallelism(): Int = backend.defaultParallelism() // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 529237f0d35dc..d509881c74fef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.Arrays +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -29,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -97,7 +99,8 @@ private[spark] class TaskSetManager( var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] - override def runningTasks = runningTasksSet.size + + override def runningTasks: Int = runningTasksSet.size // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the @@ -168,9 +171,9 @@ private[spark] class TaskSetManager( var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level - override def schedulableQueue = null + override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null - override def schedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.NONE var emittedTaskSizeWarning = false @@ -585,7 +588,7 @@ private[spark] class TaskSetManager( /** * Marks the task as getting result and notifies the DAG Scheduler */ - def handleTaskGettingResult(tid: Long) = { + def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) @@ -612,7 +615,7 @@ private[spark] class TaskSetManager( /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { + def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index info.markSuccessful() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 87ebf31139ce9..5d258d9da4d1a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -85,7 +85,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } - def receiveWithLogging = { + def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisterExecutor(executorId, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f14aaeea0a25c..5a38ad9f2b12c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] abstract class YarnSchedulerBackend( context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case RegisterClusterManager => logInfo(s"ApplicationMaster registered as $sender") amActor = Some(sender) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala index aa3ec0f8cfb9c..8df4f3b554c41 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala @@ -24,7 +24,7 @@ private[spark] object MemoryUtils { val OVERHEAD_FRACTION = 0.10 val OVERHEAD_MINIMUM = 384 - def calculateTotalMemory(sc: SparkContext) = { + def calculateTotalMemory(sc: SparkContext): Int = { sc.conf.getInt("spark.mesos.executor.memoryOverhead", math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 06bb527522141..b381436839227 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -387,7 +387,7 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + override def defaultParallelism(): Int = sc.conf.getInt("spark.default.parallelism", 8) override def applicationId(): String = Option(appId).getOrElse { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index d95426d918e19..eb3f999b5b375 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -59,7 +59,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -117,7 +117,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! ReviveOffers } - override def defaultParallelism() = + override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 1baa0e009f3ae..dfbde7c8a1b0d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -59,9 +59,10 @@ private[spark] class JavaSerializationStream( } private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { + extends DeserializationStream { + private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index dc7aa99738c17..f83bcaa5cc09e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -60,7 +60,7 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) - def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 7de2f9cbb2866..660df00bc32f5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -106,7 +106,7 @@ class FileShuffleBlockManager(conf: SparkConf) * when the writers are closed successfully */ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, - writeMetrics: ShuffleWriteMetrics) = { + writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) @@ -268,7 +268,7 @@ object FileShuffleBlockManager { new PrimitiveVector[Long]() } - def apply(bucketId: Int) = files(bucketId) + def apply(bucketId: Int): File = files(bucketId) def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { assert(offsets.length == lengths.length) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index b292587d37028..87fd161e06c85 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -80,7 +80,7 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { * end of the output file. This will be used by getBlockLocation to figure out where each block * begins and ends. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]) = { + def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) try { @@ -121,5 +121,5 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { } } - override def stop() = {} + override def stop(): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 1f012941c85ab..c186fd360fef6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -35,13 +35,13 @@ sealed abstract class BlockId { def name: String // convenience methods - def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None - def isRDD = isInstanceOf[RDDBlockId] - def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] + def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD: Boolean = isInstanceOf[RDDBlockId] + def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] + def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] - override def toString = name - override def hashCode = name.hashCode + override def toString: String = name + override def hashCode: Int = name.hashCode override def equals(other: Any): Boolean = other match { case o: BlockId => getClass == o.getClass && name.equals(o.name) case _ => false @@ -50,54 +50,54 @@ sealed abstract class BlockId { @DeveloperApi case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { - def name = "rdd_" + rddId + "_" + splitIndex + override def name: String = "rdd_" + rddId + "_" + splitIndex } // Format of the shuffle block ids (including data and index) should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). @DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" } @DeveloperApi case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { - def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) + override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } @DeveloperApi case class TaskResultBlockId(taskId: Long) extends BlockId { - def name = "taskresult_" + taskId + override def name: String = "taskresult_" + taskId } @DeveloperApi case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { - def name = "input-" + streamId + "-" + uniqueId + override def name: String = "input-" + streamId + "-" + uniqueId } /** Id associated with temporary local data managed as blocks. Not serializable. */ private[spark] case class TempLocalBlockId(id: UUID) extends BlockId { - def name = "temp_local_" + id + override def name: String = "temp_local_" + id } /** Id associated with temporary shuffle data managed as blocks. Not serializable. */ private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId { - def name = "temp_shuffle_" + id + override def name: String = "temp_shuffle_" + id } // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { - def name = "test_" + id + override def name: String = "test_" + id } @DeveloperApi @@ -112,7 +112,7 @@ object BlockId { val TEST = "test_(.*)".r /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String) = id match { + def apply(id: String): BlockId = id match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b177a59c721df..a6f1ebf325a7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -77,11 +77,11 @@ class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = s"BlockManagerId($executorId, $host, $port)" + override def toString: String = s"BlockManagerId($executorId, $host, $port)" override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case id: BlockManagerId => executorId == id.executorId && port == id.port && host == id.host case _ => @@ -100,10 +100,10 @@ private[spark] object BlockManagerId { * @param port Port of the block manager. * @return A new [[org.apache.spark.storage.BlockManagerId]]. */ - def apply(execId: String, host: String, port: Int) = + def apply(execId: String, host: String, port: Int): BlockManagerId = getCachedBlockManagerId(new BlockManagerId(execId, host, port)) - def apply(in: ObjectInput) = { + def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() obj.readExternal(in) getCachedBlockManagerId(obj) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 654796f23c96e..061964826f08b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -79,7 +79,7 @@ class BlockManagerMaster( * Check if block manager master has a block. Note that this can be used to check for only * those blocks that are reported to block manager master. */ - def contains(blockId: BlockId) = { + def contains(blockId: BlockId): Boolean = { !getLocations(blockId).isEmpty } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 787b0f96bec32..5b5328016124e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,7 +52,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -421,7 +421,7 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] - def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8462871e798a5..52fb896c4e21f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -38,7 +38,7 @@ class BlockManagerSlaveActor( import context.dispatcher // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 81164178b9e8e..f703e50b6b0ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -82,11 +82,13 @@ private[spark] class DiskBlockObjectWriter( { /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]) = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) - override def close() = out.close() - override def flush() = out.flush() + override def write(i: Int): Unit = callWithTiming(out.write(i)) + override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + callWithTiming(out.write(b, off, len)) + } + override def close(): Unit = out.close() + override def flush(): Unit = out.flush() } /** The file channel, used for repositioning / truncating the file. */ @@ -141,8 +143,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - def sync = fos.getFD.sync() - callWithTiming(sync) + callWithTiming { + fos.getFD.sync() + } } objOut.close() diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 132502b75f8cd..95e2d688d9b17 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,5 +24,7 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 120c327a7e580..0186eb30a1905 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -36,7 +36,7 @@ class RDDInfo( def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0 - override def toString = { + override def toString: String = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + "MemorySize: %s; TachyonSize: %s; DiskSize: %s").format( @@ -44,7 +44,7 @@ class RDDInfo( bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) } - override def compare(that: RDDInfo) = { + override def compare(that: RDDInfo): Int = { this.id - that.id } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index e5e1cf5a69a19..134abea866218 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -50,11 +50,11 @@ class StorageLevel private( def this() = this(false, true, false, false) // For deserialization - def useDisk = _useDisk - def useMemory = _useMemory - def useOffHeap = _useOffHeap - def deserialized = _deserialized - def replication = _replication + def useDisk: Boolean = _useDisk + def useMemory: Boolean = _useMemory + def useOffHeap: Boolean = _useOffHeap + def deserialized: Boolean = _deserialized + def replication: Int = _replication assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") @@ -80,7 +80,7 @@ class StorageLevel private( false } - def isValid = (useMemory || useDisk || useOffHeap) && (replication > 0) + def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0) def toInt: Int = { var ret = 0 @@ -183,7 +183,7 @@ object StorageLevel { useMemory: Boolean, useOffHeap: Boolean, deserialized: Boolean, - replication: Int) = { + replication: Int): StorageLevel = { getCachedStorageLevel( new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) } @@ -197,7 +197,7 @@ object StorageLevel { useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, - replication: Int = 1) = { + replication: Int = 1): StorageLevel = { getCachedStorageLevel(new StorageLevel(useDisk, useMemory, false, deserialized, replication)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index def49e80a3605..7d75929b96f75 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,7 +19,6 @@ package org.apache.spark.storage import scala.collection.mutable -import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -32,7 +31,7 @@ class StorageStatusListener extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() - def storageStatusList = executorIdToStorageStatus.values.toSeq + def storageStatusList: Seq[StorageStatus] = executorIdToStorageStatus.values.toSeq /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { @@ -56,7 +55,7 @@ class StorageStatusListener extends SparkListener { } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { @@ -67,7 +66,7 @@ class StorageStatusListener extends SparkListener { } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { updateStorageStatus(unpersistRDD.rddId) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala index b86abbda1d3e7..65fa81704c365 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala @@ -24,5 +24,7 @@ import tachyon.client.TachyonFile * a length. */ private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 0c24ad2760e08..adfa6bbada256 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,7 +60,7 @@ private[spark] class SparkUI private ( } initialize() - def getAppName = appName + def getAppName: String = appName /** Set the app name for this UI. */ def setAppName(name: String) { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index b5022fe853c49..f07864141a21c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -149,9 +149,11 @@ private[spark] object UIUtils extends Logging { } } - def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource + def prependBaseUri(basePath: String = "", resource: String = ""): String = { + uiRoot + basePath + resource + } - def commonHeaderNodes = { + def commonHeaderNodes: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index fc1844600f1cb..19ac7a826e306 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -51,7 +51,7 @@ private[spark] object UIWorkloadGenerator { val nJobSet = args(2).toInt val sc = new SparkContext(conf) - def setProperties(s: String) = { + def setProperties(s: String): Unit = { if(schedulingMode == SchedulingMode.FAIR) { sc.setLocalProperty("spark.scheduler.pool", s) } @@ -59,7 +59,7 @@ private[spark] object UIWorkloadGenerator { } val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = new Random().nextFloat() + def nextFloat(): Float = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( ("Count", baseData.count), diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 3afd7ef07d7c9..69053fe44d7e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.ExceptionFailure import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ -import org.apache.spark.storage.StorageStatusListener +import org.apache.spark.storage.{StorageStatus, StorageStatusListener} import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { @@ -55,19 +55,19 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToShuffleWrite = HashMap[String, Long]() val executorToLogUrls = HashMap[String, Map[String, String]]() - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList - override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) = synchronized { + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap } - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 937d95a934b59..949e80d30f5eb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -73,7 +73,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Misc: val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() - def blockManagerIds = executorIdToBlockManagerId.values.toSeq + def blockManagerIds: Seq[BlockManagerId] = executorIdToBlockManagerId.values.toSeq var schedulingMode: Option[SchedulingMode] = None @@ -146,7 +146,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { val jobGroup = for ( props <- Option(jobStart.properties); group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -182,7 +182,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) @@ -219,7 +219,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { val stage = stageCompleted.stageInfo stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { @@ -260,7 +260,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val stage = stageSubmitted.stageInfo activeStages(stage.stageId) = stage pendingStages.remove(stage.stageId) @@ -288,7 +288,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { @@ -312,7 +312,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // stageToTaskInfos already has the updated status. } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task // completion event is for. Let's just drop it here. This means we might have some speculation diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index b2bbfdee56946..7ffcf291b5cc6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -24,7 +24,7 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { val sc = parent.sc val killEnabled = parent.killEnabled - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) val listener = parent.jobProgressListener attachPage(new AllJobsPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 110f8780a9a12..e03442894c5cc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} +import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -170,7 +170,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") - def accumulableRow(acc: AccumulableInfo) = {acc.name}{acc.value} + def accumulableRow(acc: AccumulableInfo): Elem = + {acc.name}{acc.value} val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.values.toSeq) @@ -293,10 +294,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantiles(data: Seq[Double]) = + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) = { + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) + : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator getDistributionQuantiles(data).map(d => {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 937261de00e3a..1bd2d87e00796 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -32,10 +32,10 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" attachPage(new StagePage(this)) attachPage(new PoolPage(this)) - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) - def handleKillRequest(request: HttpServletRequest) = { - if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { + def handleKillRequest(request: HttpServletRequest): Unit = { + if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index dbf1ceeda1878..711a3697bda15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -94,11 +94,11 @@ private[jobs] object UIData { var taskData = new HashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] - def hasInput = inputBytes > 0 - def hasOutput = outputBytes > 0 - def hasShuffleRead = shuffleReadTotalBytes > 0 - def hasShuffleWrite = shuffleWriteBytes > 0 - def hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 + def hasInput: Boolean = inputBytes > 0 + def hasOutput: Boolean = outputBytes > 0 + def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 + def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 + def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0 } /** diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index a81291d505583..045bd784990d1 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -40,10 +40,10 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList /** Filter RDD info to include only those with cached partitions */ - def rddInfoList = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq + def rddInfoList: Seq[RDDInfo] = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq /** Update the storage info of the RDDs whose blocks are among the given updated blocks */ private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { @@ -56,19 +56,19 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar * Assumes the storage status list is fully up-to-date. This implies the corresponding * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener. */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val metrics = taskEnd.taskMetrics if (metrics != null && metrics.updatedBlocks.isDefined) { updateRDDInfo(metrics.updatedBlocks.get) } } - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val rddInfos = stageSubmitted.stageInfo.rddInfos rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { // Remove all partitions that are no longer cached in current completed stage val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet _rddInfoMap.retain { case (id, info) => @@ -76,7 +76,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { _rddInfoMap.remove(unpersistRDD.rddId) } } diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 390310243ee0a..9044aaeef2d48 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -27,8 +27,8 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat // scalastyle:on private[this] var completed = false - def next() = sub.next() - def hasNext = { + def next(): A = sub.next() + def hasNext: Boolean = { val r = sub.hasNext if (!r && !completed) { completed = true @@ -37,13 +37,13 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat r } - def completion() + def completion(): Unit } private[spark] object CompletionIterator { - def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { + def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A, I] = { new CompletionIterator[A,I](sub) { - def completion() = completionFunction + def completion(): Unit = completionFunction } } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index a465298c8c5ab..9aea8efa38c7a 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -57,7 +57,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va out.println } - def statCounter = StatCounter(data.slice(startIdx, endIdx)) + def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) /** * print a summary of this distribution to the given PrintStream. diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index cf89c1782fd67..1718554061985 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -39,31 +39,27 @@ private[spark] class ManualClock(private var time: Long) extends Clock { /** * @param timeToSet new time (in milliseconds) that the clock should represent */ - def setTime(timeToSet: Long) = - synchronized { - time = timeToSet - notifyAll() - } + def setTime(timeToSet: Long): Unit = synchronized { + time = timeToSet + notifyAll() + } /** * @param timeToAdd time (in milliseconds) to add to the clock's time */ - def advance(timeToAdd: Long) = - synchronized { - time += timeToAdd - notifyAll() - } + def advance(timeToAdd: Long): Unit = synchronized { + time += timeToAdd + notifyAll() + } /** * @param targetTime block until the clock time is set or advanced to at least this time * @return current time reported by the clock when waiting finishes */ - def waitTillTime(targetTime: Long): Long = - synchronized { - while (time < targetTime) { - wait(100) - } - getTimeMillis() + def waitTillTime(targetTime: Long): Long = synchronized { + while (time < targetTime) { + wait(100) } - + getTimeMillis() + } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index ac40f19ed6799..375ed430bde45 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -67,14 +67,15 @@ private[spark] object MetadataCleanerType extends Enumeration { type MetadataCleanerType = Value - def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = - "spark.cleaner.ttl." + which.toString + def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = { + "spark.cleaner.ttl." + which.toString + } } // TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. private[spark] object MetadataCleaner { - def getDelaySeconds(conf: SparkConf) = { + def getDelaySeconds(conf: SparkConf): Int = { conf.getInt("spark.cleaner.ttl", -1) } diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index 74fa77b68de0b..dad888548ed10 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -43,7 +43,7 @@ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef this } - override def toString = "(" + _1 + "," + _2 + ")" + override def toString: String = "(" + _1 + "," + _2 + ")" override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] } diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 6d8d9e8da3678..73d126ff6254e 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -22,7 +22,7 @@ package org.apache.spark.util */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { - override def findClass(name: String) = { + override def findClass(name: String): Class[_] = { super.findClass(name) } diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala index 770ff9d5ad6ae..a06b6f84ef11b 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -27,7 +27,7 @@ import java.nio.channels.Channels */ private[spark] class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { - def value = buffer + def value: ByteBuffer = buffer private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { val length = in.readInt() diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index d80eed455c427..8586da1996cf3 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -141,8 +141,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { object StatCounter { /** Build a StatCounter from a list of values. */ - def apply(values: TraversableOnce[Double]) = new StatCounter(values) + def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values) /** Build a StatCounter from a list of values passed as variable-length arguments. */ - def apply(values: Double*) = new StatCounter(values) + def apply(values: Double*): StatCounter = new StatCounter(values) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f5be5856c2109..310c0c109416c 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -82,7 +82,7 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo this } - override def update(key: A, value: B) = this += ((key, value)) + override def update(key: A, value: B): Unit = this += ((key, value)) override def apply(key: A): B = internalMap.apply(key) @@ -92,14 +92,14 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) + override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f) def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) def toMap: Map[A, B] = iterator.toMap /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ - def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime) /** Remove entries with values that are no longer strongly reachable. */ def clearNullValues() { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fa56bb09e4e5c..d9a671687aad0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -85,7 +85,7 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } ois.readObject.asInstanceOf[T] @@ -106,11 +106,10 @@ private[spark] object Utils extends Logging { /** Serialize via nested stream using specific serializer */ def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)( - f: SerializationStream => Unit) = { + f: SerializationStream => Unit): Unit = { val osWrapper = ser.serializeStream(new OutputStream { - def write(b: Int) = os.write(b) - - override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + override def write(b: Int): Unit = os.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len) }) try { f(osWrapper) @@ -121,10 +120,9 @@ private[spark] object Utils extends Logging { /** Deserialize via nested stream using specific serializer */ def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)( - f: DeserializationStream => Unit) = { + f: DeserializationStream => Unit): Unit = { val isWrapper = ser.deserializeStream(new InputStream { - def read(): Int = is.read() - + override def read(): Int = is.read() override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) }) try { @@ -137,7 +135,7 @@ private[spark] object Utils extends Logging { /** * Get the ClassLoader which loaded Spark. */ - def getSparkClassLoader = getClass.getClassLoader + def getSparkClassLoader: ClassLoader = getClass.getClassLoader /** * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that @@ -146,7 +144,7 @@ private[spark] object Utils extends Logging { * This should be used whenever passing a ClassLoader to Class.ForName or finding the currently * active loader when setting up ClassLoader delegation chains. */ - def getContextOrSparkClassLoader = + def getContextOrSparkClassLoader: ClassLoader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader) /** Determines whether the provided class is loadable in the current thread. */ @@ -155,12 +153,14 @@ private[spark] object Utils extends Logging { } /** Preferred alternative to Class.forName(className) */ - def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + def classForName(className: String): Class[_] = { + Class.forName(className, true, getContextOrSparkClassLoader) + } /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { @@ -1557,7 +1557,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ - def getFormattedClassName(obj: AnyRef) = { + def getFormattedClassName(obj: AnyRef): String = { obj.getClass.getSimpleName.replace("$", "") } @@ -1570,7 +1570,7 @@ private[spark] object Utils extends Logging { } /** Return an empty JSON object */ - def emptyJson = JObject(List[JField]()) + def emptyJson: JsonAST.JObject = JObject(List[JField]()) /** * Return a Hadoop FileSystem with the scheme encoded in the given path. @@ -1618,7 +1618,7 @@ private[spark] object Utils extends Logging { /** * Indicates whether Spark is currently running unit tests. */ - def isTesting = { + def isTesting: Boolean = { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index af1f64649f354..f79e8e0491ea1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -156,10 +156,10 @@ class BitSet(numBits: Int) extends Serializable { /** * Get an iterator over the set bits. */ - def iterator = new Iterator[Int] { + def iterator: Iterator[Int] = new Iterator[Int] { var ind = nextSetBit(0) override def hasNext: Boolean = ind >= 0 - override def next() = { + override def next(): Int = { val tmp = ind ind = nextSetBit(ind + 1) tmp diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a0f5a602de12..9ff4744593d4d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -159,7 +159,7 @@ class ExternalAppendOnlyMap[K, V, C]( val batchSizes = new ArrayBuffer[Long] // Flush the disk writer's contents to disk, and update relevant variables - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() @@ -355,7 +355,7 @@ class ExternalAppendOnlyMap[K, V, C]( val pairs: ArrayBuffer[(K, C)]) extends Comparable[StreamBuffer] { - def isEmpty = pairs.length == 0 + def isEmpty: Boolean = pairs.length == 0 // Invalid if there are no more pairs in this stream def minKeyHash: Int = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index d69f2d9048055..3262e670c2030 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -283,7 +283,7 @@ private[spark] class ExternalSorter[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables. // The writer is closed at the end of this process, and cannot be reused. - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index b8de4ff9aa494..c52591b352340 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -109,7 +109,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = -1 var nextPair: (K, V) = computeNextPair() @@ -132,9 +132,9 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4e363b74f4bef..c80057f95e0b2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -85,7 +85,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( protected var _bitset = new BitSet(_capacity) - def getBitSet = _bitset + def getBitSet: BitSet = _bitset // Init of the array in constructor (instead of in declaration) to work around a Scala compiler // specialization bug that would generate two arrays (one for Object and one for specialized T). @@ -183,7 +183,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( /** Return the value at the specified position. */ def getValue(pos: Int): T = _data(pos) - def iterator = new Iterator[T] { + def iterator: Iterator[T] = new Iterator[T] { var pos = nextPos(0) override def hasNext: Boolean = pos != INVALID_POS override def next(): T = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index 2e1ef06cbc4e1..61e22642761f0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -46,7 +46,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = _keySet.size + override def size: Int = _keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -87,7 +87,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -103,9 +103,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index c5268c0fae0ef..bdbca00a00622 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -32,7 +32,7 @@ private[spark] object Utils { */ def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = { val ordering = new GuavaOrdering[T] { - override def compare(l: T, r: T) = ord.compare(l, r) + override def compare(l: T, r: T): Int = ord.compare(l, r) } collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator } diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 1d5467060623c..14b6ba4af489a 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -121,7 +121,7 @@ private[spark] object FileAppender extends Logging { val rollingSizeBytes = conf.get(SIZE_PROPERTY, STRATEGY_DEFAULT) val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT) - def createTimeBasedAppender() = { + def createTimeBasedAppender(): FileAppender = { val validatedParams: Option[(Long, String)] = rollingInterval match { case "daily" => logInfo(s"Rolling executor logs enabled for $file with daily rolling") @@ -149,7 +149,7 @@ private[spark] object FileAppender extends Logging { } } - def createSizeBasedAppender() = { + def createSizeBasedAppender(): FileAppender = { rollingSizeBytes match { case IntParam(bytes) => logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes") diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 76e7a2760bcd1..786b97ad7b9ec 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -105,7 +105,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals private val rng: Random = new XORShiftRandom - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (ub - lb <= 0.0) { @@ -131,7 +131,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals def cloneComplement(): BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, !complement) - override def clone = new BernoulliCellSampler[T](lb, ub, complement) + override def clone: BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, complement) } @@ -153,7 +153,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T private val rng: Random = RandomSampler.newDefaultRNG - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { @@ -167,7 +167,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T } } - override def clone = new BernoulliSampler[T](fraction) + override def clone: BernoulliSampler[T] = new BernoulliSampler[T](fraction) } @@ -209,7 +209,7 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] } } - override def clone = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) } @@ -228,15 +228,18 @@ class GapSamplingIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -244,21 +247,21 @@ class GapSamplingIterator[T: ClassTag]( override def next(): T = { val r = data.next() - advance + advance() r } private val lnq = math.log1p(-f) /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / lnq).toInt iterDrop(k) } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. @@ -279,15 +282,18 @@ class GapSamplingReplacementIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -300,7 +306,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( override def next(): T = { val r = v rep -= 1 - if (rep <= 0) advance + if (rep <= 0) advance() r } @@ -309,7 +315,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is * q is the probabililty of Poisson(0; f) */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / (-f)).toInt iterDrop(k) @@ -343,7 +349,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 2ae308dacf1ae..9e29bf9d61f17 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -311,7 +311,7 @@ private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: var acceptBound: Double = Double.NaN // upper bound for accepting item instantly var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist - def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN + def areBoundsEmpty: Boolean = acceptBound.isNaN || waitListBound.isNaN def merge(other: Option[AcceptanceResult]): Unit = { if (other.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 467b890fb4bb9..c4a7b4441c85c 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -83,7 +83,7 @@ private[spark] object XORShiftRandom { * @return Map of execution times for {@link java.util.Random java.util.Random} * and XORShift */ - def benchmark(numIters: Int) = { + def benchmark(numIters: Int): Map[String, Long] = { val seed = 1L val million = 1e6.toInt diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 328d59485a731..56f5dbe53fad4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,7 +44,13 @@ object MimaExcludes { // the maven-generated artifacts in 1.3. excludePackage("org.spark-project.jetty"), MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") ) case v if v.startsWith("1.3") => diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 0ff521706c71a..459a5035d4984 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,9 +137,9 @@ - + - + From 1afcf773d0cafdfd9bf106fdc5c429ed2ba3dd36 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 24 Mar 2015 01:12:11 -0700 Subject: [PATCH 47/47] [SPARK-6452] [SQL] Checks for missing attributes and unresolved operator for all types of operator In `CheckAnalysis`, `Filter` and `Aggregate` are checked in separate case clauses, thus never hit those clauses for unresolved operators and missing input attributes. This PR also removes the `prettyString` call when generating error message for missing input attributes. Because result of `prettyString` doesn't contain expression ID, and may give confusing messages like > resolved attributes a missing from a cc rxin [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5129) Author: Cheng Lian Closes #5129 from liancheng/spark-6452 and squashes the following commits: 52cdc69 [Cheng Lian] Addresses comments 029f9bd [Cheng Lian] Checks for missing attributes and unresolved operator for all types of operator --- .../sql/catalyst/analysis/CheckAnalysis.scala | 15 ++++++++++----- .../spark/sql/catalyst/plans/QueryPlan.scala | 7 +++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fb975ee5e7296..425e1e41cbf21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -63,7 +63,7 @@ class CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) => + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK case e: Attribute if !groupingExprs.contains(e) => @@ -85,13 +85,18 @@ class CheckAnalysis { cleaned.foreach(checkValidAggregateExpression) + case _ => // Fallbacks to the following checks + } + + operator match { case o if o.children.nonEmpty && o.missingInput.nonEmpty => - val missingAttributes = o.missingInput.map(_.prettyString).mkString(",") - val input = o.inputSet.map(_.prettyString).mkString(",") + val missingAttributes = o.missingInput.mkString(",") + val input = o.inputSet.mkString(",") - failAnalysis(s"resolved attributes $missingAttributes missing from $input") + failAnalysis( + s"resolved attribute(s) $missingAttributes missing from $input " + + s"in operator ${operator.simpleString}") - // Catch all case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 400a6b2825c10..48191f31198f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -47,9 +47,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. + * + * Note that virtual columns should be excluded. Currently, we only support the grouping ID + * virtual column. */ - def missingInput: AttributeSet = (references -- inputSet) - .filter(_.name != VirtualColumn.groupingIdName) + def missingInput: AttributeSet = + (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c1dd5aa913ddc..359aec4a7b5ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -199,4 +199,22 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + LocalRelation( + AttributeReference("a", StringType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } }