From e22110879cd149e94c9a5ca7466f787033572b15 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 2 Aug 2014 12:11:50 -0700 Subject: [PATCH 001/231] [HOTFIX] Do not throw NPE if spark.test.home is not set `spark.test.home` was introduced in #1734. This is fine for SBT but is failing maven tests. Either way it shouldn't throw an NPE. Author: Andrew Or Closes #1739 from andrewor14/fix-spark-test-home and squashes the following commits: ce2624c [Andrew Or] Do not throw NPE if spark.test.home is not set --- .../scala/org/apache/spark/deploy/worker/Worker.scala | 9 +++++++-- core/src/test/scala/org/apache/spark/DriverSuite.scala | 2 +- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 +- .../apache/spark/deploy/worker/ExecutorRunnerTest.scala | 2 +- pom.xml | 8 ++++---- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c6ea42fceb659..458d9947bd873 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -71,7 +71,7 @@ private[spark] class Worker( // TTL for app folders/data; after TTL expires it will be cleaned up val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) - + val testing: Boolean = sys.props.contains("spark.testing") val masterLock: Object = new Object() var master: ActorSelection = null var masterAddress: Address = null @@ -82,7 +82,12 @@ private[spark] class Worker( @volatile var connected = false val workerId = generateWorkerId() val sparkHome = - new File(sys.props.get("spark.test.home").orElse(sys.env.get("SPARK_HOME")).getOrElse(".")) + if (testing) { + assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") + new File(sys.props("spark.test.home")) + } else { + new File(sys.env.get("SPARK_HOME").getOrElse(".")) + } var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index e36902ec81e08..a73e1ef0288a5 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -34,7 +34,7 @@ import scala.language.postfixOps class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 8126ef1bb23aa..a5cdcfb5de03b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -295,7 +295,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. def runSparkSubmit(args: Seq[String]): String = { - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) Utils.executeAndGetOutput( Seq("./bin/spark-submit") ++ args, new File(sparkHome), diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 149a2b3d95b86..39ab53cf0b5b1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { def f(s:String) = new File(s) - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl") val appId = "12345-worker321-9876" diff --git a/pom.xml b/pom.xml index ae97bf03c53a2..99ae4b8b33f94 100644 --- a/pom.xml +++ b/pom.xml @@ -868,10 +868,10 @@ ${project.build.directory}/SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - - ${session.executionRootDirectory} - 1 - + + ${session.executionRootDirectory} + 1 + From 8d6ac2b95ab48d9fffe82ef04cef3b22c2c139e0 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 2 Aug 2014 13:07:17 -0700 Subject: [PATCH 002/231] [SPARK-2478] [mllib] DecisionTree Python API Added experimental Python API for Decision Trees. API: * class DecisionTreeModel ** predict() for single examples and RDDs, taking both feature vectors and LabeledPoints ** numNodes() ** depth() ** __str__() * class DecisionTree ** trainClassifier() ** trainRegressor() ** train() Examples and testing: * Added example testing classification and regression with batch prediction: examples/src/main/python/mllib/tree.py * Have also tested example usage in doc of python/pyspark/mllib/tree.py which tests single-example prediction with dense and sparse vectors Also: Small bug fix in python/pyspark/mllib/_common.py: In _linear_predictor_typecheck, changed check for RDD to use isinstance() instead of type() in order to catch RDD subclasses. CC mengxr manishamde Author: Joseph K. Bradley Closes #1727 from jkbradley/decisiontree-python-new and squashes the following commits: 3744488 [Joseph K. Bradley] Renamed test tree.py to decision_tree_runner.py Small updates based on github review. 6b86a9d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new affceb9 [Joseph K. Bradley] * Fixed bug in doc tests in pyspark/mllib/util.py caused by change in loadLibSVMFile behavior. (It used to threshold labels at 0 to make them 0/1, but it now leaves them as they are.) * Fixed small bug in loadLibSVMFile: If a data file had no features, then loadLibSVMFile would create a single all-zero feature. 67a29bc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new cf46ad7 [Joseph K. Bradley] Python DecisionTreeModel * predict(empty RDD) returns an empty RDD instead of an error. * Removed support for calling predict() on LabeledPoint and RDD[LabeledPoint] * predict() does not cache serialized RDD any more. aa29873 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new bf21be4 [Joseph K. Bradley] removed old run() func from DecisionTree fa10ea7 [Joseph K. Bradley] Small style update 7968692 [Joseph K. Bradley] small braces typo fix e34c263 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 4801b40 [Joseph K. Bradley] Small style update to DecisionTreeSuite db0eab2 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix2' into decisiontree-python-new 6873fa9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. 93953f1 [Joseph K. Bradley] Likely done with Python API. 6df89a9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 4562c08 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 665ba78 [Joseph K. Bradley] Small updates towards Python DecisionTree API 188cb0d [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 6622247 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new b8fac57 [Joseph K. Bradley] Finished Python DecisionTree API and example but need to test a bit more. 2b20c61 [Joseph K. Bradley] Small doc and style updates 1b29c13 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 584449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals 8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type. 376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. * In code, replaced usages of maxDepth <-- maxDepth + 1 * In params, replace settings of maxDepth <-- maxDepth - 1 e06e423 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new bab3f19 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo=Classification. * Updates based on comments: ** DecisionTreeRunner *** Made dataFormat arg default to libsvm ** Small cleanups ** tree.Node: Made recursive helper methods private, and renamed them. 52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix f5a036c [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now. 8e227ea [Joseph K. Bradley] Changed Strategy so it only requires numClassesForClassification >= 2 for classification cd1d933 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features. 8a758db [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 5fe44ed [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 2283df8 [Joseph K. Bradley] 2 bug fixes. 73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit. f825352 [Joseph K. Bradley] Wrote Python API and example for DecisionTree. Also added toString, depth, and numNodes methods to DecisionTreeModel. (cherry picked from commit 3f67382e7c9c3f6a8f6ce124ab3fcb1a9c1a264f) Signed-off-by: Xiangrui Meng --- .../main/python/mllib/decision_tree_runner.py | 133 +++++++++++ .../main/python/mllib/logistic_regression.py | 4 +- .../mllib/api/python/PythonMLLibAPI.scala | 78 ++++++ .../mllib/tree/configuration/Strategy.scala | 3 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 3 +- python/pyspark/mllib/_common.py | 33 ++- python/pyspark/mllib/tests.py | 36 +++ python/pyspark/mllib/tree.py | 225 ++++++++++++++++++ python/pyspark/mllib/util.py | 14 +- python/run-tests | 1 + 10 files changed, 509 insertions(+), 21 deletions(-) create mode 100755 examples/src/main/python/mllib/decision_tree_runner.py create mode 100644 python/pyspark/mllib/tree.py diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py new file mode 100755 index 0000000000000..8efadb5223f56 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision tree classification and regression using MLlib. +""" + +import numpy, os, sys + +from operator import add + +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + + +def getAccuracy(dtModel, data): + """ + Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + (x[0] == x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainCorrect / (0.0 + data.count()) + + +def getMSE(dtModel, data): + """ + Return mean squared error (MSE) of DecisionTreeModel on the given + RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainMSE / (0.0 + data.count()) + + +def reindexClassLabels(data): + """ + Re-index class labels in a dataset to the range {0,...,numClasses-1}. + If all labels in that range already appear at least once, + then the returned RDD is the same one (without a mapping). + Note: If a label simply does not appear in the data, + the index will not include it. + Be aware of this when reindexing subsampled data. + :param data: RDD of LabeledPoint where labels are integer values + denoting labels for a classification problem. + :return: Pair (reindexedData, origToNewLabels) where + reindexedData is an RDD of LabeledPoint with labels in + the range {0,...,numClasses-1}, and + origToNewLabels is a dictionary mapping original labels + to new labels. + """ + # classCounts: class --> # examples in class + classCounts = data.map(lambda x: x.label).countByValue() + numExamples = sum(classCounts.values()) + sortedClasses = sorted(classCounts.keys()) + numClasses = len(classCounts) + # origToNewLabels: class --> index in 0,...,numClasses-1 + if (numClasses < 2): + print >> sys.stderr, \ + "Dataset for classification should have at least 2 classes." + \ + " The given dataset had only %d classes." % numClasses + exit(1) + origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) + + print "numClasses = %d" % numClasses + print "Per-class example fractions, counts:" + print "Class\tFrac\tCount" + for c in sortedClasses: + frac = classCounts[c] / (numExamples + 0.0) + print "%g\t%g\t%d" % (c, frac, classCounts[c]) + + if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): + return (data, origToNewLabels) + else: + reindexedData = \ + data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) + return (reindexedData, origToNewLabels) + + +def usage(): + print >> sys.stderr, \ + "Usage: decision_tree_runner [libsvm format data filepath]\n" + \ + " Note: This only supports binary classification." + exit(1) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + usage() + sc = SparkContext(appName="PythonDT") + + # Load data. + dataPath = 'data/mllib/sample_libsvm_data.txt' + if len(sys.argv) == 2: + dataPath = sys.argv[1] + if not os.path.isfile(dataPath): + usage() + points = MLUtils.loadLibSVMFile(sc, dataPath) + + # Re-index class labels if needed. + (reindexedData, origToNewLabels) = reindexClassLabels(points) + + # Train a classifier. + model = DecisionTree.trainClassifier(reindexedData, numClasses=2) + # Print learned tree and stats. + print "Trained DecisionTree for classification:" + print " Model numNodes: %d\n" % model.numNodes() + print " Model depth: %d\n" % model.depth() + print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) + print model diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 6e0f7a4ee5a81..9d547ff77c984 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -30,8 +30,10 @@ from pyspark.mllib.classification import LogisticRegressionWithSGD -# Parse a line of text into an MLlib LabeledPoint object def parsePoint(line): + """ + Parse a line of text into an MLlib LabeledPoint object. + """ values = [float(s) for s in line.split(' ')] if values[0] == -1: # Convert -1 labels to 0 for MLlib values[0] = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 7d912737b8f0b..1d5d3762ed8e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ @@ -29,6 +31,11 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils @@ -472,6 +479,76 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for Python mllib DecisionTree.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + * @param dataBytesJRDD Training data + * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + */ + def trainDecisionTreeModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + impurityStr: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + + val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + + val algo: Algo = algoStr match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") + } + val impurity: Impurity = impurityStr match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") + } + + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) + + DecisionTree.train(data, strategy) + } + + /** + * Predict the label of the given data point. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param featuresBytes Serialized feature vector for data point + * @return predicted label + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + featuresBytes: Array[Byte]): Double = { + val features: Vector = deserializeDoubleVector(featuresBytes) + model.predict(features) + } + + /** + * Predict the labels of the given data points. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param dataJRDD A JavaRDD with serialized feature vectors + * @return JavaRDD of serialized predictions + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { + val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + model.predict(data).map(serializeDouble) + } + /** * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String). * Returns the correlation matrix serialized into a byte array understood by deserializers in @@ -597,4 +674,5 @@ class PythonMLLibAPI extends Serializable { val s = getSeedOrDefault(seed) RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 5c65b537b6867..fdad4f029aa99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -56,7 +56,8 @@ class Strategy ( if (algo == Classification) { require(numClassesForClassification >= 2) } - val isMulticlassClassification = numClassesForClassification > 2 + val isMulticlassClassification = + algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 546a132559326..8665a00f3b356 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -48,7 +48,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { requiredMSE: Double) { val predictions = input.map(x => model.predict(x.features)) val squaredError = predictions.zip(input).map { case (prediction, expected) => - (prediction - expected.label) * (prediction - expected.label) + val err = prediction - expected.label + err * err }.sum val mse = squaredError / input.length assert(mse <= requiredMSE) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index c6ca6a75df746..9c1565affbdac 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype): temp_array[...] = array -def _get_unmangled_rdd(data, serializer): +def _get_unmangled_rdd(data, serializer, cache=True): + """ + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ dataBytes = data.map(serializer) dataBytes._bypass_serializer = True - dataBytes.cache() # TODO: users should unpersist() this later! + if cache: + dataBytes.cache() return dataBytes -# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - return _get_unmangled_rdd(data, _serialize_double_vector) +def _get_unmangled_double_vector_rdd(data, cache=True): + """ + Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of + _serialized_double_vectors. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_double_vector, cache) -# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points -def _get_unmangled_labeled_point_rdd(data): - return _get_unmangled_rdd(data, _serialize_labeled_point) +def _get_unmangled_labeled_point_rdd(data, cache=True): + """ + Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_labeled_point, cache) # Common functions for dealing with and training linear models @@ -380,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs): if x.size != coeffs.shape[0]: raise RuntimeError("Got sparse vector of size %d; wanted %d" % ( x.size, coeffs.shape[0])) - elif (type(x) == RDD): + elif isinstance(x, RDD): raise RuntimeError("Bulk predict not yet supported.") else: raise TypeError("Argument of type " + type(x).__name__ + " unsupported") diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 37ccf1d590743..9d1e5be637a9a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -100,6 +100,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -127,9 +128,19 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = \ + DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -157,6 +168,14 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = \ + DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -229,6 +248,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -256,9 +276,18 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -286,6 +315,13 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + if __name__ == "__main__": if not _have_scipy: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py new file mode 100644 index 0000000000000..1e0006df75ac6 --- /dev/null +++ b/python/pyspark/mllib/tree.py @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_collections import MapConverter + +from pyspark import SparkContext, RDD +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \ + _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \ + _deserialize_double +from pyspark.mllib.regression import LabeledPoint +from pyspark.serializers import NoOpSerializer + +class DecisionTreeModel(object): + """ + A decision tree model for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + """ + + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def predict(self, x): + """ + Predict the label of one or more examples. + :param x: Data point (feature vector), + or an RDD of data points (feature vectors). + """ + pythonAPI = self._sc._jvm.PythonMLLibAPI() + if isinstance(x, RDD): + # Bulk prediction + if x.count() == 0: + return self._sc.parallelize([]) + dataBytes = _get_unmangled_double_vector_rdd(x, cache=False) + jSerializedPreds = \ + pythonAPI.predictDecisionTreeModel(self._java_model, + dataBytes._jrdd) + serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer()) + return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes))) + else: + # Assume x is a single data point. + x_ = _serialize_double_vector(x) + return pythonAPI.predictDecisionTreeModel(self._java_model, x_) + + def numNodes(self): + return self._java_model.numNodes() + + def depth(self): + return self._java_model.depth() + + def __str__(self): + return self._java_model.toString() + + +class DecisionTree(object): + """ + Learning algorithm for a decision tree model + for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + + Example usage: + >>> from numpy import array, ndarray + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(1.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) + >>> print(model) + DecisionTreeModel classifier + If (feature 0 <= 0.5) + Predict: 0.0 + Else (feature 0 > 0.5) + Predict: 1.0 + + >>> model.predict(array([1.0])) > 0 + True + >>> model.predict(array([0.0])) == 0 + True + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model.predict(array([0.0, 1.0])) == 1 + True + >>> model.predict(array([0.0, 0.0])) == 0 + True + >>> model.predict(SparseVector(2, {1: 1.0})) == 1 + True + >>> model.predict(SparseVector(2, {1: 0.0})) == 0 + True + """ + + @staticmethod + def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + impurity="gini", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for classification. + + :param data: Training data: RDD of LabeledPoint. + Labels are integers {0,1,...,numClasses}. + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: Supported values: "entropy" or "gini" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "classification", numClasses, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + @staticmethod + def trainRegressor(data, categoricalFeaturesInfo={}, + impurity="variance", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for regression. + + :param data: Training data: RDD of LabeledPoint. + Labels are real numbers. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: Supported values: "variance" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "regression", 0, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + + @staticmethod + def train(data, algo, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins=100): + """ + Train a DecisionTreeModel for classification or regression. + + :param data: Training data: RDD of LabeledPoint. + For classification, labels are integers + {0,1,...,numClasses}. + For regression, labels are real numbers. + :param algo: "classification" or "regression" + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: For classification: "entropy" or "gini". + For regression: "variance". + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, algo, + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index d94900cefdb77..639cda6350229 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -16,6 +16,7 @@ # import numpy as np +import warnings from pyspark.mllib.linalg import Vectors, SparseVector from pyspark.mllib.regression import LabeledPoint @@ -29,9 +30,9 @@ class MLUtils: Helper methods to load, save and pre-process data used in MLlib. """ - @deprecated @staticmethod def _parse_libsvm_line(line, multiclass): + warnings.warn("deprecated", DeprecationWarning) return _parse_libsvm_line(line) @staticmethod @@ -67,9 +68,9 @@ def _convert_labeled_point_to_libsvm(p): " but got " % type(v)) return " ".join(items) - @deprecated @staticmethod def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + warnings.warn("deprecated", DeprecationWarning) return loadLibSVMFile(sc, path, numFeatures, minPartitions) @staticmethod @@ -106,7 +107,6 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() - >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() >>> type(examples[0]) == LabeledPoint True @@ -115,20 +115,18 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): >>> type(examples[1]) == LabeledPoint True >>> print examples[1] - (0.0,(6,[],[])) + (-1.0,(6,[],[])) >>> type(examples[2]) == LabeledPoint True >>> print examples[2] - (0.0,(6,[1,3,5],[4.0,5.0,6.0])) - >>> multiclass_examples[1].label - -1.0 + (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) """ lines = sc.textFile(path, minPartitions) parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) if numFeatures <= 0: parsed.cache() - numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 + numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod diff --git a/python/run-tests b/python/run-tests index 5049e15ce5f8a..48feba2f5bd63 100755 --- a/python/run-tests +++ b/python/run-tests @@ -71,6 +71,7 @@ run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green From 91de0dc1654d609dc1ff8fa9a07ba18043ad61c6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 2 Aug 2014 13:16:41 -0700 Subject: [PATCH 003/231] [SQL] Set outputPartitioning of BroadcastHashJoin correctly. I think we will not generate the plan triggering this bug at this moment. But, let me explain it... Right now, we are using `left.outputPartitioning` as the `outputPartitioning` of a `BroadcastHashJoin`. We may have a wrong physical plan for cases like... ```sql SELECT l.key, count(*) FROM (SELECT key, count(*) as cnt FROM src GROUP BY key) l // This is buildPlan JOIN r // This is the streamedPlan ON (l.cnt = r.value) GROUP BY l.key ``` Let's say we have a `BroadcastHashJoin` on `l` and `r`. For this case, we will pick `l`'s `outputPartitioning` for the `outputPartitioning`of the `BroadcastHashJoin` on `l` and `r`. Also, because the last `GROUP BY` is using `l.key` as the key, we will not introduce an `Exchange` for this aggregation. However, `r`'s outputPartitioning may not match the required distribution of the last `GROUP BY` and we fail to group data correctly. JIRA is being reindexed. I will create a JIRA ticket once it is back online. Author: Yin Huai Closes #1735 from yhuai/BroadcastHashJoin and squashes the following commits: 96d9cb3 [Yin Huai] Set outputPartitioning correctly. (cherry picked from commit 67bd8e3c217a80c3117a6e3853aa60fe13d08c91) Signed-off-by: Michael Armbrust --- .../src/main/scala/org/apache/spark/sql/execution/joins.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index cc138c749949d..51bb61530744c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -405,8 +405,7 @@ case class BroadcastHashJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil From bb0ac6d7c91c491a99c252e6cb4aea40efe9b190 Mon Sep 17 00:00:00 2001 From: Chris Fregly Date: Sat, 2 Aug 2014 13:35:35 -0700 Subject: [PATCH 004/231] [SPARK-1981] Add AWS Kinesis streaming support Author: Chris Fregly Closes #1434 from cfregly/master and squashes the following commits: 4774581 [Chris Fregly] updated docs, renamed retry to retryRandom to be more clear, removed retries around store() method 0393795 [Chris Fregly] moved Kinesis examples out of examples/ and back into extras/kinesis-asl 691a6be [Chris Fregly] fixed tests and formatting, fixed a bug with JavaKinesisWordCount during union of streams 0e1c67b [Chris Fregly] Merge remote-tracking branch 'upstream/master' 74e5c7c [Chris Fregly] updated per TD's feedback. simplified examples, updated docs e33cbeb [Chris Fregly] Merge remote-tracking branch 'upstream/master' bf614e9 [Chris Fregly] per matei's feedback: moved the kinesis examples into the examples/ dir d17ca6d [Chris Fregly] per TD's feedback: updated docs, simplified the KinesisUtils api 912640c [Chris Fregly] changed the foundKinesis class to be a publically-avail class db3eefd [Chris Fregly] Merge remote-tracking branch 'upstream/master' 21de67f [Chris Fregly] Merge remote-tracking branch 'upstream/master' 6c39561 [Chris Fregly] parameterized the versions of the aws java sdk and kinesis client 338997e [Chris Fregly] improve build docs for kinesis 828f8ae [Chris Fregly] more cleanup e7c8978 [Chris Fregly] Merge remote-tracking branch 'upstream/master' cd68c0d [Chris Fregly] fixed typos and backward compatibility d18e680 [Chris Fregly] Merge remote-tracking branch 'upstream/master' b3b0ff1 [Chris Fregly] [SPARK-1981] Add AWS Kinesis streaming support (cherry picked from commit 91f9504e6086fac05b40545099f9818949c24bca) Signed-off-by: Tathagata Das --- bin/run-example | 3 +- bin/run-example2.cmd | 3 +- dev/audit-release/audit_release.py | 4 +- .../src/main/scala/SparkApp.scala | 7 + dev/audit-release/sbt_app_kinesis/build.sbt | 28 ++ .../src/main/scala/SparkApp.scala | 33 +++ dev/create-release/create-release.sh | 4 +- dev/run-tests | 3 + docs/streaming-custom-receivers.md | 4 +- docs/streaming-kinesis.md | 58 ++++ docs/streaming-programming-guide.md | 12 +- examples/pom.xml | 13 + extras/kinesis-asl/pom.xml | 96 ++++++ .../streaming/JavaKinesisWordCountASL.java | 180 ++++++++++++ .../src/main/resources/log4j.properties | 37 +++ .../streaming/KinesisWordCountASL.scala | 251 ++++++++++++++++ .../kinesis/KinesisCheckpointState.scala | 56 ++++ .../streaming/kinesis/KinesisReceiver.scala | 149 ++++++++++ .../kinesis/KinesisRecordProcessor.scala | 212 ++++++++++++++ .../streaming/kinesis/KinesisUtils.scala | 96 ++++++ .../kinesis/JavaKinesisStreamSuite.java | 41 +++ .../src/test/resources/log4j.properties | 26 ++ .../kinesis/KinesisReceiverSuite.scala | 275 ++++++++++++++++++ pom.xml | 10 + project/SparkBuild.scala | 6 +- 25 files changed, 1592 insertions(+), 15 deletions(-) create mode 100644 dev/audit-release/sbt_app_kinesis/build.sbt create mode 100644 dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala create mode 100644 docs/streaming-kinesis.md create mode 100644 extras/kinesis-asl/pom.xml create mode 100644 extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java create mode 100644 extras/kinesis-asl/src/main/resources/log4j.properties create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala create mode 100644 extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java create mode 100644 extras/kinesis-asl/src/test/resources/log4j.properties create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala diff --git a/bin/run-example b/bin/run-example index 942706d733122..68a35702eddd3 100755 --- a/bin/run-example +++ b/bin/run-example @@ -29,7 +29,8 @@ if [ -n "$1" ]; then else echo "Usage: ./bin/run-example [example-args]" 1>&2 echo " - set MASTER=XX to use a specific master" 1>&2 - echo " - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression)" 1>&2 + echo " - can use abbreviated example class name relative to com.apache.spark.examples" 1>&2 + echo " (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL)" 1>&2 exit 1 fi diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index eadedd7fa61ff..b29bf90c64e90 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -32,7 +32,8 @@ rem Test that an argument was given if not "x%1"=="x" goto arg_given echo Usage: run-example ^ [example-args] echo - set MASTER=XX to use a specific master - echo - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression) + echo - can use abbreviated example class name relative to com.apache.spark.examples + echo (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL) goto exit :arg_given diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 230e900ecd4de..16ea1a71290dc 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -105,7 +105,7 @@ def get_url(url): "spark-core", "spark-bagel", "spark-mllib", "spark-streaming", "spark-repl", "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-mqtt", "spark-streaming-twitter", "spark-streaming-zeromq", - "spark-catalyst", "spark-sql", "spark-hive" + "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" ] modules = map(lambda m: "%s_%s" % (m, SCALA_BINARY_VERSION), modules) @@ -136,7 +136,7 @@ def ensure_path_not_present(x): os.chdir(original_dir) # SBT application tests -for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive"]: +for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive", "sbt_app_kinesis"]: os.chdir(app) ret = run_cmd("sbt clean run", exit_on_failure=False) test(ret == 0, "sbt application (%s)" % app) diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index 77bbd167b199a..fc03fec9866a6 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -50,5 +50,12 @@ object SimpleApp { println("Ganglia sink was loaded via spark-core") System.exit(-1) } + + // Remove kinesis from default build due to ASL license issue + val foundKinesis = Try(Class.forName("org.apache.spark.streaming.kinesis.KinesisUtils")).isSuccess + if (foundKinesis) { + println("Kinesis was loaded via spark-core") + System.exit(-1) + } } } diff --git a/dev/audit-release/sbt_app_kinesis/build.sbt b/dev/audit-release/sbt_app_kinesis/build.sbt new file mode 100644 index 0000000000000..981bc7957b5ed --- /dev/null +++ b/dev/audit-release/sbt_app_kinesis/build.sbt @@ -0,0 +1,28 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +name := "Kinesis Test" + +version := "1.0" + +scalaVersion := System.getenv.get("SCALA_VERSION") + +libraryDependencies += "org.apache.spark" %% "spark-streaming-kinesis-asl" % System.getenv.get("SPARK_VERSION") + +resolvers ++= Seq( + "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala new file mode 100644 index 0000000000000..9f85066501472 --- /dev/null +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main.scala + +import scala.util.Try + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +object SimpleApp { + def main(args: Array[String]) { + val foundKinesis = Try(Class.forName("org.apache.spark.streaming.kinesis.KinesisUtils")).isSuccess + if (!foundKinesis) { + println("Kinesis not loaded via kinesis-asl") + System.exit(-1) + } + } +} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index af46572e6602b..42473629d4f15 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -53,15 +53,15 @@ if [[ ! "$@" =~ --package-only ]]; then -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ --batch-mode release:prepare mvn -DskipTests \ -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ release:perform cd .. diff --git a/dev/run-tests b/dev/run-tests index daa85bc750c07..d401c90f41d7b 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -36,6 +36,9 @@ fi if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi + +export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" + echo "SBT_MAVEN_PROFILES_ARGS=\"$SBT_MAVEN_PROFILES_ARGS\"" # Remove work directory diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a2dc3a8961dfc..1e045a3dd0ca9 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, files, sockets, etc.). +the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. @@ -174,7 +174,7 @@ val words = lines.flatMap(_.split(" ")) ... {% endhighlight %} -The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/streaming/examples/CustomReceiver.scala). +The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md new file mode 100644 index 0000000000000..801c905c88df8 --- /dev/null +++ b/docs/streaming-kinesis.md @@ -0,0 +1,58 @@ +--- +layout: global +title: Spark Streaming Kinesis Receiver +--- + +### Kinesis +Build notes: +
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • +
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • +
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • +
  • To build with Kinesis, you must run the maven or sbt builds with -Pkinesis-asl`.
  • +
  • Applications will need to link to the 'spark-streaming-kinesis-asl` artifact.
  • + +Kinesis examples notes: +
  • To build the Kinesis examples, you must run the maven or sbt builds with -Pkinesis-asl`.
  • +
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • +
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • +
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • + +Deployment and runtime notes: +
  • A single KinesisReceiver can process many shards of a stream.
  • +
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream.
  • +
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • +
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    + 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    + 2) Java System Properties - aws.accessKeyId and aws.secretKey
    + 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    + 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    +
  • +
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    + http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, +retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • +
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). +Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, +it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
  • + +Failure recovery notes: +
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    + 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    + 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    + 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
    +
  • +
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • +
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) +or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • +
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data +depending on the checkpoint frequency.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
  • +
  • Record processing should be idempotent when possible.
  • +
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • +
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
  • diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 7b8b7933434c4..9f331ed50d2a4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -9,7 +9,7 @@ title: Spark Streaming Programming Guide # Overview Spark Streaming is an extension of the core Spark API that allows enables high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ or plain old TCP sockets and be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis or plain old TCP sockets and be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's in-built @@ -38,7 +38,7 @@ stream of results in batches. Spark Streaming provides a high-level abstraction called *discretized stream* or *DStream*, which represents a continuous stream of data. DStreams can be created either from input data -stream from sources such as Kafka and Flume, or by applying high-level +stream from sources such as Kafka, Flume, and Kinesis, or by applying high-level operations on other DStreams. Internally, a DStream is represented as a sequence of [RDDs](api/scala/index.html#org.apache.spark.rdd.RDD). @@ -313,7 +313,7 @@ To write your own Spark Streaming program, you will have to add the following de artifactId = spark-streaming_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION}} -For ingesting data from sources like Kafka and Flume that are not present in the Spark +For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark Streaming core API, you will have to add the corresponding artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, @@ -327,6 +327,7 @@ some of the common ones are as follows. Twitter spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}} ZeroMQ spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}} MQTT spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}} + Kinesis
    (built separately) kinesis-asl_{{site.SCALA_BINARY_VERSION}} @@ -442,7 +443,7 @@ see the API documentations of the relevant functions in Scala and [JavaStreamingContext](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) for Java. -Additional functionality for creating DStreams from sources such as Kafka, Flume, and Twitter +Additional functionality for creating DStreams from sources such as Kafka, Flume, Kinesis, and Twitter can be imported by adding the right dependencies as explained in an [earlier](#linking) section. To take the case of Kafka, after adding the artifact `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` to the @@ -467,6 +468,9 @@ For more details on these additional sources, see the corresponding [API documen Furthermore, you can also implement your own custom receiver for your sources. See the [Custom Receiver Guide](streaming-custom-receivers.html). +### Kinesis +[Kinesis](streaming-kinesis.html) + ## Operations There are two kinds of DStream operations - _transformations_ and _output operations_. Similar to RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams diff --git a/examples/pom.xml b/examples/pom.xml index c4ed0f5a6a02b..8c4c128bb484d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,6 +34,19 @@ Spark Project Examples http://spark.apache.org/ + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + + + org.apache.spark diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml new file mode 100644 index 0000000000000..a54b34235dfb4 --- /dev/null +++ b/extras/kinesis-asl/pom.xml @@ -0,0 +1,96 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + + org.apache.spark + spark-streaming-kinesis-asl_2.10 + jar + Spark Kinesis Integration + + + kinesis-asl + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + + + com.amazonaws + amazon-kinesis-client + ${aws.kinesis.client.version} + + + com.amazonaws + aws-java-sdk + ${aws.java.sdk.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.mockito + mockito-all + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.easymock + easymockclassextension + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java new file mode 100644 index 0000000000000..a8b907b241893 --- /dev/null +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.streaming; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import org.apache.log4j.Logger; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.kinesis.KinesisUtils; + +import scala.Tuple2; + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.kinesis.AmazonKinesisClient; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import com.google.common.collect.Lists; + +/** + * Java-friendly Kinesis Spark Streaming WordCount example + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details + * on the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: JavaKinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data + * onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. + */ +public final class JavaKinesisWordCountASL { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + /* Make the constructor private to enforce singleton */ + private JavaKinesisWordCountASL() { + } + + public static void main(String[] args) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + "|Usage: KinesisWordCount \n" + + "| is the name of the Kinesis stream\n" + + "| is the endpoint of the Kinesis service\n" + + "| (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + /* Populate the appropriate variables from the given args */ + String streamName = args[0]; + String endpointUrl = args[1]; + /* Set the batch interval to a fixed 2000 millis (2 seconds) */ + Duration batchInterval = new Duration(2000); + + /* Create a Kinesis client in order to determine the number of shards for the given stream */ + AmazonKinesisClient kinesisClient = new AmazonKinesisClient( + new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + + /* Determine the number of shards from the stream */ + int numShards = kinesisClient.describeStream(streamName) + .getStreamDescription().getShards().size(); + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ + int numStreams = numShards; + + /* Must add 1 more thread than the number of receivers or the output won't show properly from the driver */ + int numSparkThreads = numStreams + 1; + + /* Setup the Spark config. */ + SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount").setMaster( + "local[" + numSparkThreads + "]"); + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + Duration checkpointInterval = batchInterval; + + /* Setup the StreamingContext */ + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) + ); + } + + /* Union all the streams if there is more than 1 stream */ + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + /* Otherwise, just use the 1 stream */ + unionStreams = streamsList.get(0); + } + + /* + * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. + * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. + */ + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. */ + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + /* Print the first 10 wordCounts */ + wordCounts.print(); + + /* Start the streaming context and await termination */ + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties new file mode 100644 index 0000000000000..97348fb5b6123 --- /dev/null +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +log4j.rootCategory=WARN, console + +# File appender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Console appender +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.out +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala new file mode 100644 index 0000000000000..d03edf8b30a9f --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import java.nio.ByteBuffer +import scala.util.Random +import org.apache.spark.Logging +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.PutRecordRequest +import org.apache.log4j.Logger +import org.apache.log4j.Level + +/** + * Kinesis Spark Streaming WordCount example. + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on + * the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: KinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class below called KinesisWordCountProducerASL which puts + * dummy data onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + */ +object KinesisWordCountASL extends Logging { + def main(args: Array[String]) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + """ + |Usage: KinesisWordCount + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. https://kinesis.us-east-1.amazonaws.com) + """.stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(streamName, endpointUrl) = args + + /* Determine the number of shards from the stream */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpointUrl) + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() + .size() + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + val numStreams = numShards + + /* + * numSparkThreads should be 1 more thread than the number of receivers. + * This leaves one thread available for actually processing the data. + */ + val numSparkThreads = numStreams + 1 + + /* Setup the and SparkConfig and StreamingContext */ + /* Spark Streaming batch interval */ + val batchInterval = Milliseconds(2000) + val sparkConfig = new SparkConf().setAppName("KinesisWordCount") + .setMaster(s"local[$numSparkThreads]") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + val kinesisCheckpointInterval = batchInterval + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + val kinesisStreams = (0 until numStreams).map { i => + KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + } + + /* Union all the streams */ + val unionStreams = ssc.union(kinesisStreams) + + /* Convert each line of Array[Byte] to String, split into words, and count them */ + val words = unionStreams.flatMap(byteArray => new String(byteArray) + .split(" ")) + + /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) + + /* Print the first 10 wordCounts */ + wordCounts.print() + + /* Start the streaming context and await termination */ + ssc.start() + ssc.awaitTermination() + } +} + +/** + * Usage: KinesisWordCountProducerASL + * + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * is the rate of records per second to put onto the stream + * is the rate of records per second to put onto the stream + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com 10 5 + */ +object KinesisWordCountProducerASL { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: KinesisWordCountProducerASL " + + " ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args + + /* Generate the records and return the totals */ + val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + + /* Print the array of (index, total) tuples */ + println("Totals") + totals.foreach(total => println(total.toString())) + } + + def generate(stream: String, + endpoint: String, + recordsPerSecond: Int, + wordsPerRecord: Int): Seq[(Int, Int)] = { + + val MaxRandomInts = 10 + + /* Create the Kinesis client */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpoint) + + println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + + s" $recordsPerSecond records per second and $wordsPerRecord words per record"); + + val totals = new Array[Int](MaxRandomInts) + /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ + for (i <- 1 to 5) { + + /* Generate recordsPerSec records to put onto the stream */ + val records = (1 to recordsPerSecond.toInt).map { recordNum => + /* + * Randomly generate each wordsPerRec words between 0 (inclusive) + * and MAX_RANDOM_INTS (exclusive) + */ + val data = (1 to wordsPerRecord.toInt).map(x => { + /* Generate the random int */ + val randomInt = Random.nextInt(MaxRandomInts) + + /* Keep track of the totals */ + totals(randomInt) += 1 + + randomInt.toString() + }).mkString(" ") + + /* Create a partitionKey based on recordNum */ + val partitionKey = s"partitionKey-$recordNum" + + /* Create a PutRecordRequest with an Array[Byte] version of the data */ + val putRecordRequest = new PutRecordRequest().withStreamName(stream) + .withPartitionKey(partitionKey) + .withData(ByteBuffer.wrap(data.getBytes())); + + /* Put the record onto the stream and capture the PutRecordResult */ + val putRecordResult = kinesisClient.putRecord(putRecordRequest); + } + + /* Sleep for a second */ + Thread.sleep(1000) + println("Sent " + recordsPerSecond + " records") + } + + /* Convert the totals to (index, total) tuple */ + (0 to (MaxRandomInts - 1)).zip(totals) + } +} + +/** + * Utility functions for Spark Streaming examples. + * This has been lifted from the examples/ project to remove the circular dependency. + */ +object StreamingExamples extends Logging { + + /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + def setStreamingLogLevels() { + val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements + if (!log4jInitialized) { + // We first log something to initialize Spark's default logging, then we override the + // logging level. + logInfo("Setting log level to [WARN] for streaming example." + + " To override add a custom log4j.properties to the classpath.") + Logger.getRootLogger.setLevel(Level.WARN) + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala new file mode 100644 index 0000000000000..0b80b611cdce7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.util.SystemClock + +/** + * This is a helper class for managing checkpoint clocks. + * + * @param checkpointInterval + * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) + */ +private[kinesis] class KinesisCheckpointState( + checkpointInterval: Duration, + currentClock: Clock = new SystemClock()) + extends Logging { + + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ + val checkpointClock = new ManualClock() + checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + + /** + * Check if it's time to checkpoint based on the current time and the derived time + * for the next checkpoint + * + * @return true if it's time to checkpoint + */ + def shouldCheckpoint(): Boolean = { + new SystemClock().currentTime() > checkpointClock.currentTime() + } + + /** + * Advance the checkpoint clock by the checkpoint interval. + */ + def advanceCheckpoint() = { + checkpointClock.addToTime(checkpointInterval.milliseconds) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala new file mode 100644 index 0000000000000..1bd1f324298e7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.net.InetAddress +import java.util.UUID + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.receiver.Receiver + +import com.amazonaws.auth.AWSCredentialsProvider +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +/** + * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. + * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: + * https://github.com/awslabs/amazon-kinesis-client + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) + * as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers + * to run within a Spark Executor. + * + * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams + * by the Kinesis Client Library. If you change the App name or Stream name, + * the KCL will throw errors. This usually requires deleting the backing + * DynamoDB table with the same name this Kinesis application. + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ +private[kinesis] class KinesisReceiver( + appName: String, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel) + extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => + + /* + * The following vars are built in the onStart() method which executes in the Spark Worker after + * this code is serialized and shipped remotely. + */ + + /* + * workerId should be based on the ip address of the actual Spark Worker where this code runs + * (not the Driver's ip address.) + */ + var workerId: String = null + + /* + * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file at the default location (~/.aws/credentials) shared by all + * AWS SDKs and the AWS CLI + * Instance profile credentials delivered through the Amazon EC2 metadata service + */ + var credentialsProvider: AWSCredentialsProvider = null + + /* KCL config instance. */ + var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + var recordProcessorFactory: IRecordProcessorFactory = null + + /* + * Create a Kinesis Worker. + * This is the core client abstraction from the Kinesis Client Library (KCL). + * We pass the RecordProcessorFactory from above as well as the KCL config instance. + * A Kinesis Worker can process 1..* shards from the given stream - each with its + * own RecordProcessor. + */ + var worker: Worker = null + + /** + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through the Worker.run() + * method. + */ + override def onStart() { + workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + credentialsProvider = new DefaultAWSCredentialsProviderChain() + kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, + credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) + recordProcessorFactory = new IRecordProcessorFactory { + override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, + workerId, new KinesisCheckpointState(checkpointInterval)) + } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) + worker.run() + logInfo(s"Started receiver with workerId $workerId") + } + + /** + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + */ + override def onStop() { + worker.shutdown() + logInfo(s"Shut down receiver with workerId $workerId") + workerId = null + credentialsProvider = null + kinesisClientLibConfiguration = null + recordProcessorFactory = null + worker = null + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala new file mode 100644 index 0000000000000..8ecc2d90160b1 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.List + +import scala.collection.JavaConversions.asScalaBuffer +import scala.util.Random + +import org.apache.spark.Logging + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. + * This implementation operates on the Array[Byte] from the KinesisReceiver. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * + * @param receiver Kinesis receiver + * @param workerId for logging purposes + * @param checkpointState represents the checkpoint state including the next checkpoint time. + * It's injected here for mocking purposes. + */ +private[kinesis] class KinesisRecordProcessor( + receiver: KinesisReceiver, + workerId: String, + checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { + + /* shardId to be populated during initialize() */ + var shardId: String = _ + + /** + * The Kinesis Client Library calls this method during IRecordProcessor initialization. + * + * @param shardId assigned by the KCL to this particular RecordProcessor. + */ + override def initialize(shardId: String) { + logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") + this.shardId = shardId + } + + /** + * This method is called by the KCL when a batch of records is pulled from the Kinesis stream. + * This is the record-processing bridge between the KCL's IRecordProcessor.processRecords() + * and Spark Streaming's Receiver.store(). + * + * @param batch list of records from the Kinesis stream shard + * @param checkpointer used to update Kinesis when this batch has been processed/stored + * in the DStream + */ + override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { + if (!receiver.isStopped()) { + try { + /* + * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). + * This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + */ + batch.foreach(record => receiver.store(record.getData().array())) + + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + + /* + * Checkpoint the sequence number of the last record successfully processed/stored + * in the batch. + * In this implementation, we're checkpointing after the given checkpointIntervalMillis. + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the + * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. + * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). + * However, if the worker dies unexpectedly, a checkpoint may not happen. + * This could lead to records being processed more than once. + */ + if (checkpointState.shouldCheckpoint()) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* Update the next checkpoint time */ + checkpointState.advanceCheckpoint() + + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + + s" records for shardId $shardId") + logDebug(s"Checkpoint: Next checkpoint is at " + + s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + } + } catch { + case e: Throwable => { + /* + * If there is a failure within the batch, the batch will not be checkpointed. + * This will potentially cause records since the last checkpoint to be processed + * more than once. + */ + logError(s"Exception: WorkerId $workerId encountered and exception while storing " + + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + throw e + } + } + } else { + /* RecordProcessor has been stopped. */ + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + s" and shardId $shardId. No more records will be processed.") + } + } + + /** + * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: + * 1) the stream is resharding by splitting or merging adjacent shards + * (ShutdownReason.TERMINATE) + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * (ShutdownReason.ZOMBIE) + * + * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE + * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) + */ + override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) { + logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason") + reason match { + /* + * TERMINATE Use Case. Checkpoint. + * Checkpoint to indicate that all records from the shard have been drained and processed. + * It's now OK to read from the new shards that resulted from a resharding event. + */ + case ShutdownReason.TERMINATE => + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* + * ZOMBIE Use Case. NoOp. + * No checkpoint because other workers may have taken over and already started processing + * the same records. + * This may lead to records being processed more than once. + */ + case ShutdownReason.ZOMBIE => + + /* Unknown reason. NoOp */ + case _ => + } + } +} + +private[kinesis] object KinesisRecordProcessor extends Logging { + /** + * Retry the given amount of times with a random backoff time (millis) less than the + * given maxBackOffMillis + * + * @param expression expression to evalute + * @param numRetriesLeft number of retries left + * @param maxBackOffMillis: max millis between retries + * + * @return evaluation of the given expression + * @throws Unretryable exception, unexpected exception, + * or any exception that persists after numRetriesLeft reaches 0 + */ + @annotation.tailrec + def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = { + util.Try { expression } match { + /* If the function succeeded, evaluate to x. */ + case util.Success(x) => x + /* If the function failed, either retry or throw the exception */ + case util.Failure(e) => e match { + /* Retry: Throttling or other Retryable exception has occurred */ + case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 + => { + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) + } + /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + case _: ShutdownException => { + logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) + throw e + } + /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ + case _: InvalidStateException => { + logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + + s" by the Amazon Kinesis Client Library. Table likely doesn't exist.", e) + throw e + } + /* Throw: Unexpected exception has occurred */ + case _ => { + logError(s"Unexpected, non-retryable exception.", e) + throw e + } + } + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala new file mode 100644 index 0000000000000..713cac0e293c0 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream +import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.dstream.ReceiverInputDStream + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + +/** + * Helper class to create Amazon Kinesis Input Stream + * :: Experimental :: + */ +@Experimental +object KinesisUtils { + /** + * Create an InputDStream that pulls messages from a Kinesis stream. + * + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ + def createStream( + ssc: StreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, + checkpointInterval, initialPositionInStream, storageLevel)) + } + + /** + * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. + * + * @param jssc Java StreamingContext object + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return JavaReceiverInputDStream[Array[Byte]] + */ + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { + jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, + endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + } +} diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java new file mode 100644 index 0000000000000..87954a31f60ce --- /dev/null +++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +/** + * Demonstrate the use of the KinesisUtils Java API + */ +public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { + @Test + public void testKinesisStream() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()); + + ssc.stop(); + } +} diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e01e049595475 --- /dev/null +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala new file mode 100644 index 0000000000000..41dbd64c2b1fa --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.JavaConversions.seqAsJavaList + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.Seconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers +import org.scalatest.mock.EasyMockSugar + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + */ +class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter + with EasyMockSugar { + + val app = "TestKinesisReceiver" + val stream = "mySparkStream" + val endpoint = "endpoint-url" + val workerId = "dummyWorkerId" + val shardId = "dummyShardId" + + val record1 = new Record() + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + val record2 = new Record() + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) + val batch = List[Record](record1, record2) + + var receiverMock: KinesisReceiver = _ + var checkpointerMock: IRecordProcessorCheckpointer = _ + var checkpointClockMock: ManualClock = _ + var checkpointStateMock: KinesisCheckpointState = _ + var currentClockMock: Clock = _ + + override def beforeFunction() = { + receiverMock = mock[KinesisReceiver] + checkpointerMock = mock[IRecordProcessorCheckpointer] + checkpointClockMock = mock[ManualClock] + checkpointStateMock = mock[KinesisCheckpointState] + currentClockMock = mock[Clock] + } + + test("kinesis utils api") { + val ssc = new StreamingContext(master, framework, batchDuration) + // Tests the API, does not actually test data receiving + val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + ssc.stop() + } + + test("process records including store and checkpoint") { + val expectedCheckpointIntervalMillis = 10 + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).once() + receiverMock.store(record2.getData().array()).once() + checkpointStateMock.shouldCheckpoint().andReturn(true).once() + checkpointerMock.checkpoint().once() + checkpointStateMock.advanceCheckpoint().once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + test("shouldn't store and checkpoint when receiver is stopped") { + expecting { + receiverMock.isStopped().andReturn(true).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + test("shouldn't checkpoint when exception occurs during store") { + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + } + + test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + } + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + } + } + + test("should add to time when advancing checkpoint") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) + } + } + + test("shutdown should checkpoint if the reason is TERMINATE") { + expecting { + checkpointerMock.checkpoint().once() + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, reason) + } + } + + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { + expecting { + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + } + } + + test("retry success on first attempt") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis throttling exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new ThrottlingException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis dependency exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry failed after a shutdown exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after an invalid state exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after unexpected exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after exhausing all retries") { + val expectedErrorMessage = "final try error message" + expecting { + checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) + .andThrow(new ThrottlingException(expectedErrorMessage)).once() + } + whenExecuting(checkpointerMock) { + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + exception.getMessage().shouldBe(expectedErrorMessage) + } + } +} diff --git a/pom.xml b/pom.xml index 99ae4b8b33f94..a42759169149b 100644 --- a/pom.xml +++ b/pom.xml @@ -134,6 +134,8 @@ 3.0.0 1.7.6 0.7.1 + 1.8.3 + 1.1.0 64m 512m @@ -1011,6 +1013,14 @@ + + + kinesis-asl + + extras/kinesis-asl + + + java8-tests diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1d7cc6dd6aef3..aac621fe53938 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -37,8 +37,8 @@ object BuildCommons { "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = + Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") .map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") @@ -62,7 +62,7 @@ object SparkBuild extends PomBuild { var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq.empty if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { - println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pganglia-lgpl flag.") + println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") } if (Properties.envOrNone("SPARK_HIVE").isDefined) { From 7924d72cf8aae945d72f355c54c4fcb3d62e6c48 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 2 Aug 2014 13:55:28 -0700 Subject: [PATCH 005/231] SPARK-2804: Remove scalalogging-slf4j dependency This also Closes #1701. Author: GuoQiang Li Closes #1208 from witgo/SPARK-1470 and squashes the following commits: 422646b [GuoQiang Li] Remove scalalogging-slf4j dependency --- .../main/scala/org/apache/spark/Logging.scala | 10 ++++++--- sql/catalyst/pom.xml | 5 ----- .../sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +++---- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../codegen/GenerateOrdering.scala | 4 ++-- .../apache/spark/sql/catalyst/package.scala | 1 - .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 6 ++--- .../spark/sql/catalyst/rules/Rule.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 12 +++++----- .../spark/sql/catalyst/trees/package.scala | 8 ++++--- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../CompressibleColumnBuilder.scala | 5 +++-- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 2 -- .../spark/sql/columnar/ColumnTypeSuite.scala | 4 ++-- .../hive/thriftserver/HiveThriftServer2.scala | 12 +++++----- .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 6 ++--- .../sql/hive/thriftserver/SparkSQLEnv.scala | 6 ++--- .../server/SparkSQLOperationManager.scala | 13 ++++++----- .../thriftserver/HiveThriftServer2Suite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- .../org/apache/spark/sql/hive/TestHive.scala | 10 ++++----- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 ++-- .../hive/execution/HiveComparisonTest.scala | 22 +++++++++---------- .../hive/execution/HiveQueryFileTest.scala | 2 +- 30 files changed, 83 insertions(+), 82 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 807ef3e9c9d60..d4f2624061e35 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -39,13 +39,17 @@ trait Logging { // be serialized and used on another machine @transient private var log_ : Logger = null + // Method to get the logger name for this object + protected def logName = { + // Ignore trailing $'s in the class names for Scala objects + this.getClass.getName.stripSuffix("$") + } + // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { initializeIfNecessary() - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - log_ = LoggerFactory.getLogger(className.stripSuffix("$")) + log_ = LoggerFactory.getLogger(logName) } log_ } diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 54fa96baa1e18..58d44e7923bee 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -54,11 +54,6 @@ spark-core_${scala.binary.version} ${project.version} - - com.typesafe - scalalogging-slf4j_${scala.binary.version} - 1.0.1 - org.scalatest scalatest_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 74c0104e5b17f..2ba68cab115fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - logger.trace(s"Attempting to resolve ${q.simpleString}") + logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolve(name).getOrElse(u) - logger.debug(s"Resolving $u to $result") + logDebug(s"Resolving $u to $result") result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 47c7ad076ad07..e94f2a3bea63e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -75,7 +75,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -154,7 +154,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +170,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + logDebug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +178,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + logDebug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index f38f99569f207..0913f15888780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4211998f7511a..094ff14552283 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.typesafe.scalalogging.slf4j.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NumericType} @@ -92,7 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } new $orderingName() """ - logger.debug(s"Generated Ordering: $code") + logDebug(s"Generated Ordering: $code") toolBox.eval(code).asInstanceOf[Ordering[Row]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index ca9642954eb27..bdd07bbeb2230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -25,5 +25,4 @@ package object catalyst { */ protected[catalyst] object ScalaReflectionLock - protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 781ba489b44c6..5839c9f7c43ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index bc763a4e06e67..90923fe31a063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -184,7 +184,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logger.debug(s"Considering join on: $condition") + logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. val (joinPredicates, otherPredicates) = @@ -202,7 +202,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index f8960b3fe7a17..03414b2301e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6aa407c836aec..d192b151ac1c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -60,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - logger.trace( + logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -73,26 +73,26 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") } continue = false } if (curPlan.fastEquals(lastPlan)) { - logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + logTrace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - logger.debug( + logDebug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - logger.trace(s"Batch ${batch.name} has no effect.") + logTrace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 9a28d035a10a3..d725a92c06f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.Logging + /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -31,8 +33,8 @@ package org.apache.spark.sql.catalyst *
  • debugging support - pretty printing, easy splicing of trees, etc.
  • * */ -package object trees { +package object trees extends Logging { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. - protected val logger = - com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) + protected override def logName = "catalyst.trees" + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dad71079c29b9..00dd34aabc389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 4c6675c3c87bf..6ad12a0dcb64d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.Logging +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} @@ -101,7 +102,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] copyColumnHeader(rawBuffer, compressedBuffer) - logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") encoder.compress(rawBuffer, compressedBuffer, columnType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 30712f03cab4c..77dc2ad733215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -101,7 +101,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - logger.debug( + logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 70db1ebd3a3e1..a3d2a1c7a51f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 0995a4eb6299f..f513eae9c2d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -32,8 +32,6 @@ import org.apache.spark.annotation.DeveloperApi */ package object sql { - protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging - /** * :: DeveloperApi :: * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 829342215e691..75f653f3280bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -22,7 +22,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -166,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.rewind() seq.foreach { expected => - logger.info("buffer = " + buffer + ", expected = " + expected) + logInfo("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) assert( expected === extracted, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index ddbc2a79fb512..08d3f983d9e71 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -40,7 +40,7 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - logger.warn("Error starting HiveThriftServer2 with given arguments") + logWarning("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } @@ -49,12 +49,12 @@ private[hive] object HiveThriftServer2 extends Logging { // Set all properties specified via command line. val hiveConf: HiveConf = ss.getConf hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - logger.debug(s"HiveConf var: $k=$v") + logDebug(s"HiveConf var: $k=$v") } SessionState.start(ss) - logger.info("Starting SparkContext") + logInfo("Starting SparkContext") SparkSQLEnv.init() SessionState.start(ss) @@ -70,10 +70,10 @@ private[hive] object HiveThriftServer2 extends Logging { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) server.init(hiveConf) server.start() - logger.info("HiveThriftServer2 started") + logInfo("HiveThriftServer2 started") } catch { case e: Exception => - logger.error("Error starting HiveThriftServer2", e) + logError("Error starting HiveThriftServer2", e) System.exit(-1) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index cb17d7ce58ea0..4d0c506c5a397 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index a56b19a4bcda0..d362d599d08ca 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) @@ -40,7 +40,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed - logger.debug(s"Result Schema: ${analyzed.output}") + logDebug(s"Result Schema: ${analyzed.output}") if (analyzed.output.size == 0) { new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) } else { @@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo new CommandProcessorResponse(0) } catch { case cause: Throwable => - logger.error(s"Failed in [$command]", cause) + logError(s"Failed in [$command]", cause) new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 451c3bd7b9352..582264eb59f83 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { - logger.debug("Initializing SparkSQLEnv") + logDebug("Initializing SparkSQLEnv") var hiveContext: HiveContext = _ var sparkContext: SparkContext = _ @@ -47,7 +47,7 @@ private[hive] object SparkSQLEnv extends Logging { /** Cleans up and shuts down the Spark SQL environments. */ def stop() { - logger.debug("Shutting down Spark SQL Environment") + logDebug("Shutting down Spark SQL Environment") // Stop the SparkContext if (SparkSQLEnv.sparkContext != null) { sparkContext.stop() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a4e1f3e762e89..d4dadfd21d13f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,10 +30,11 @@ import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -55,7 +56,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - logger.debug("CLOSING") + logDebug("CLOSING") } def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { @@ -112,7 +113,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def getResultSetSchema: TableSchema = { - logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { @@ -124,11 +125,11 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def run(): Unit = { - logger.info(s"Running query '$statement'") + logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) try { result = hiveContext.hql(statement) - logger.debug(result.queryExecution.toString()) + logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) iter = result.queryExecution.toRdd.toLocalIterator @@ -138,7 +139,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => - logger.error("Error executing query:",e) + logError("Error executing query:",e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index fe3403b3292ec..b7b7c9957ac34 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -27,7 +27,7 @@ import java.sql.{Connection, DriverManager, Statement} import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7e3b8727bebed..2c7270d9f83a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -207,7 +207,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } catch { case e: Exception => - logger.error( + logError( s""" |====================== |HIVE FAILURE OUTPUT diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index fa4e78439c26c..df3604439e483 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,7 +28,8 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.{SQLContext, Logging} +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index c50e8c4b5c5d3..728452a25a00e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -148,7 +148,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { describedTables ++ logical.collect { case UnresolvedRelation(databaseName, name, _) => name } val referencedTestTables = referencedTables.filter(testTables.contains) - logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. analyzer(logical) @@ -273,7 +273,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infite mutually recursive table loading. loadedTables += name - logger.info(s"Loading test table $name") + logInfo(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) @@ -312,7 +312,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadedTables.clear() catalog.client.getAllTables("default").foreach { t => - logger.debug(s"Deleting table $t") + logDebug(s"Deleting table $t") val table = catalog.client.getTable("default", t) catalog.client.getIndexes("default", t, 255).foreach { index => @@ -325,7 +325,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - logger.debug(s"Dropping Database: $db") + logDebug(s"Dropping Database: $db") catalog.client.dropDatabase(db, true, false, true) } @@ -347,7 +347,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + logError(s"FATAL ERROR: Failed to reset TestDB state. $e") // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7582b4743d404..d181921269b56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ @@ -119,7 +119,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) (a: Any) => { - logger.debug( + logDebug( s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") // We must make sure that primitives get boxed java style. if (a == null) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 6c8fe4b196dea..83cfbc6b4a002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -21,7 +21,7 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} @@ -197,7 +197,7 @@ abstract class HiveComparisonTest // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return - case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") + case (shardId, _) => logDebug(s"Shard $shardId includes test '$testCaseName'") } // Skip tests found in directories specified by user. @@ -213,13 +213,13 @@ abstract class HiveComparisonTest .map(new File(_, testCaseName)) .filter(_.exists) if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { - logger.debug( + logDebug( s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") return } test(testCaseName) { - logger.debug(s"=== HIVE TEST: $testCaseName ===") + logDebug(s"=== HIVE TEST: $testCaseName ===") // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) @@ -235,7 +235,7 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") if (allQueries != queryList) - logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") + logWarning(s"Simplifications made on unsupported operations for test $testCaseName") lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -257,11 +257,11 @@ abstract class HiveComparisonTest } val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => - logger.debug(s"Looking for cached answer file $cachedAnswerFile.") + logDebug(s"Looking for cached answer file $cachedAnswerFile.") if (cachedAnswerFile.exists) { Some(fileToString(cachedAnswerFile)) } else { - logger.debug(s"File $cachedAnswerFile not found") + logDebug(s"File $cachedAnswerFile not found") None } }.map { @@ -272,7 +272,7 @@ abstract class HiveComparisonTest val hiveResults: Seq[Seq[String]] = if (hiveCachedResults.size == queryList.size) { - logger.info(s"Using answer cache for test: $testCaseName") + logInfo(s"Using answer cache for test: $testCaseName") hiveCachedResults } else { @@ -287,7 +287,7 @@ abstract class HiveComparisonTest if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) sys.error("hive exec hooks not supported for tests.") - logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i+1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. @@ -351,7 +351,7 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") if (recomputeCache) { - logger.warn(s"Clearing cache files for failed test $testCaseName") + logWarning(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) } @@ -380,7 +380,7 @@ abstract class HiveComparisonTest TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") // The testing setup traps exits so wait here for a long time so the developer can see when things started // to go wrong. Thread.sleep(1000000) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 50ab71a9003d3..02518d516261b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -53,7 +53,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { testCases.sorted.foreach { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { - logger.debug(s"Blacklisted test skipped $testCaseName") + logDebug(s"Blacklisted test skipped $testCaseName") } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) From 3b9f25f4259b254f3faa2a7d61e547089a69c259 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 16:33:48 -0700 Subject: [PATCH 006/231] [SPARK-2097][SQL] UDF Support This patch adds the ability to register lambda functions written in Python, Java or Scala as UDFs for use in SQL or HiveQL. Scala: ```scala registerFunction("strLenScala", (_: String).length) sql("SELECT strLenScala('test')") ``` Python: ```python sqlCtx.registerFunction("strLenPython", lambda x: len(x), IntegerType()) sqlCtx.sql("SELECT strLenPython('test')") ``` Java: ```java sqlContext.registerFunction("stringLengthJava", new UDF1() { Override public Integer call(String str) throws Exception { return str.length(); } }, DataType.IntegerType); sqlContext.sql("SELECT stringLengthJava('test')"); ``` Author: Michael Armbrust Closes #1063 from marmbrus/udfs and squashes the following commits: 9eda0fe [Michael Armbrust] newline 747c05e [Michael Armbrust] Add some scala UDF tests. d92727d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs 005d684 [Michael Armbrust] Fix naming and formatting. d14dac8 [Michael Armbrust] Fix last line of autogened java files. 8135c48 [Michael Armbrust] Move UDF unit tests to pyspark. 40b0ffd [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs 6a36890 [Michael Armbrust] Switch logging so that SQLContext can be serializable. 7a83101 [Michael Armbrust] Drop toString 795fd15 [Michael Armbrust] Try to avoid capturing SQLContext. e54fb45 [Michael Armbrust] Docs and tests. 437cbe3 [Michael Armbrust] Update use of dataTypes, fix some python tests, address review comments. 01517d6 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs 8e6c932 [Michael Armbrust] WIP 3f96a52 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs 6237c8d [Michael Armbrust] WIP 2766f0b [Michael Armbrust] Move udfs support to SQL from hive. Add support for Java UDFs. 0f7d50c [Michael Armbrust] Draft of native Spark SQL UDFs for Scala and Python. (cherry picked from commit 158ad0bba9382fd494b4789b5628a9cec00cfa19) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 39 ++- .../catalyst/analysis/FunctionRegistry.scala | 32 ++ .../sql/catalyst/expressions/ScalaUdf.scala | 307 ++++++++++++++++++ .../org/apache/spark/sql/api/java/UDF1.java | 32 ++ .../org/apache/spark/sql/api/java/UDF10.java | 32 ++ .../org/apache/spark/sql/api/java/UDF11.java | 32 ++ .../org/apache/spark/sql/api/java/UDF12.java | 32 ++ .../org/apache/spark/sql/api/java/UDF13.java | 32 ++ .../org/apache/spark/sql/api/java/UDF14.java | 32 ++ .../org/apache/spark/sql/api/java/UDF15.java | 32 ++ .../org/apache/spark/sql/api/java/UDF16.java | 32 ++ .../org/apache/spark/sql/api/java/UDF17.java | 32 ++ .../org/apache/spark/sql/api/java/UDF18.java | 32 ++ .../org/apache/spark/sql/api/java/UDF19.java | 32 ++ .../org/apache/spark/sql/api/java/UDF2.java | 32 ++ .../org/apache/spark/sql/api/java/UDF20.java | 32 ++ .../org/apache/spark/sql/api/java/UDF21.java | 32 ++ .../org/apache/spark/sql/api/java/UDF22.java | 32 ++ .../org/apache/spark/sql/api/java/UDF3.java | 32 ++ .../org/apache/spark/sql/api/java/UDF4.java | 32 ++ .../org/apache/spark/sql/api/java/UDF5.java | 32 ++ .../org/apache/spark/sql/api/java/UDF6.java | 32 ++ .../org/apache/spark/sql/api/java/UDF7.java | 32 ++ .../org/apache/spark/sql/api/java/UDF8.java | 32 ++ .../org/apache/spark/sql/api/java/UDF9.java | 32 ++ .../org/apache/spark/sql/SQLContext.scala | 11 +- .../apache/spark/sql/UdfRegistration.scala | 196 +++++++++++ .../spark/sql/api/java/JavaSQLContext.scala | 5 +- .../spark/sql/api/java/UDFRegistration.scala | 252 ++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 2 + .../spark/sql/execution/pythonUdfs.scala | 177 ++++++++++ .../spark/sql/api/java/JavaAPISuite.java | 90 +++++ .../apache/spark/sql/InsertIntoSuite.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 36 ++ .../apache/spark/sql/hive/HiveContext.scala | 13 +- .../org/apache/spark/sql/hive/TestHive.scala | 4 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 4 +- 38 files changed, 1861 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f840475ffaf70..e7c35ac1ffe02 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -28,9 +28,13 @@ from operator import itemgetter from pyspark.rdd import RDD, PipelinedRDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer + +from itertools import chain, ifilter, imap from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter + __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -932,6 +936,39 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + [Row(c0=5)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + command = (func, + BatchedSerializer(PickleSerializer(), 1024), + BatchedSerializer(PickleSerializer(), 1024)) + env = MapConverter().convert(self._sc.environment, + self._sc._gateway._gateway_client) + includes = ListConverter().convert(self._sc._python_includes, + self._sc._gateway._gateway_client) + self._ssql_ctx.registerPython(name, + bytearray(CloudPickleSerializer().dumps(command)), + env, + includes, + self._sc.pythonExec, + self._sc._javaAccumulator, + str(returnType)) + def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}s. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c0255701b7ba5..760c49fbca4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -18,17 +18,49 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Expression +import scala.collection.mutable /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { + type FunctionBuilder = Seq[Expression] => Expression + + def registerFunction(name: String, builder: FunctionBuilder): Unit + def lookupFunction(name: String, children: Seq[Expression]): Expression } +trait OverrideFunctionRegistry extends FunctionRegistry { + + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children)) + } +} + +class SimpleFunctionRegistry extends FunctionRegistry { + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders(name)(children) + } +} + /** * A trivial catalog that returns an error when a function is requested. Used for testing when all * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { + def registerFunction(name: String, builder: FunctionBuilder) = ??? + def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index acddf5e9c7004..95633dd0c9870 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,6 +27,22 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi def references = children.flatMap(_.references).toSet def nullable = true + /** This method has been generated by this script + + (1 to 22).map { x => + val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) + val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + + s""" + case $x => + function.asInstanceOf[($anys) => Any]( + $evals) + """ + } + + */ + + // scalastyle:off override def eval(input: Row): Any = { children.size match { case 0 => function.asInstanceOf[() => Any]() @@ -35,6 +51,297 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi function.asInstanceOf[(Any, Any) => Any]( children(0).eval(input), children(1).eval(input)) + case 3 => + function.asInstanceOf[(Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input)) + case 4 => + function.asInstanceOf[(Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input)) + case 5 => + function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input)) + case 6 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input)) + case 7 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input)) + case 8 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input)) + case 9 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input)) + case 10 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input)) + case 11 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input)) + case 12 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input)) + case 13 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input)) + case 14 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input)) + case 15 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input)) + case 16 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input)) + case 17 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input)) + case 18 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input)) + case 19 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input)) + case 20 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input)) + case 21 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input)) + case 22 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input), + children(21).eval(input)) } + // scalastyle:on } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java new file mode 100644 index 0000000000000..ef959e35e1027 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 1 arguments. + */ +public interface UDF1 extends Serializable { + public R call(T1 t1) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java new file mode 100644 index 0000000000000..96ab3a96c3d5e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 10 arguments. + */ +public interface UDF10 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java new file mode 100644 index 0000000000000..58ae8edd6d817 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 11 arguments. + */ +public interface UDF11 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java new file mode 100644 index 0000000000000..d9da0f6eddd94 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 12 arguments. + */ +public interface UDF12 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java new file mode 100644 index 0000000000000..095fc1a8076b5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 13 arguments. + */ +public interface UDF13 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java new file mode 100644 index 0000000000000..eb27eaa180086 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 14 arguments. + */ +public interface UDF14 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java new file mode 100644 index 0000000000000..1fbcff56332b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 15 arguments. + */ +public interface UDF15 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java new file mode 100644 index 0000000000000..1133561787a69 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 16 arguments. + */ +public interface UDF16 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java new file mode 100644 index 0000000000000..dfae7922c9b63 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 17 arguments. + */ +public interface UDF17 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java new file mode 100644 index 0000000000000..e9d1c6d52d4ea --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 18 arguments. + */ +public interface UDF18 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java new file mode 100644 index 0000000000000..46b9d2d3c9457 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 19 arguments. + */ +public interface UDF19 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java new file mode 100644 index 0000000000000..cd3fde8da419e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 2 arguments. + */ +public interface UDF2 extends Serializable { + public R call(T1 t1, T2 t2) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java new file mode 100644 index 0000000000000..113d3d26be4a7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 20 arguments. + */ +public interface UDF20 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java new file mode 100644 index 0000000000000..74118f2cf8da7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 21 arguments. + */ +public interface UDF21 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java new file mode 100644 index 0000000000000..0e7cc40be45ec --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 22 arguments. + */ +public interface UDF22 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java new file mode 100644 index 0000000000000..6a880f16be47a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 3 arguments. + */ +public interface UDF3 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java new file mode 100644 index 0000000000000..fcad2febb18e6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 4 arguments. + */ +public interface UDF4 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java new file mode 100644 index 0000000000000..ce0cef43a2144 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 5 arguments. + */ +public interface UDF5 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java new file mode 100644 index 0000000000000..f56b806684e61 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 6 arguments. + */ +public interface UDF6 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java new file mode 100644 index 0000000000000..25bd6d3241bd4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 7 arguments. + */ +public interface UDF7 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java new file mode 100644 index 0000000000000..a3b7ac5f94ce7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 8 arguments. + */ +public interface UDF8 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java new file mode 100644 index 0000000000000..205e72a1522fc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 9 arguments. + */ +public interface UDF9 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 00dd34aabc389..33931e5d996f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -48,18 +48,23 @@ import org.apache.spark.{Logging, SparkContext} */ @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) - extends Logging + extends org.apache.spark.Logging with SQLConf with ExpressionConversions + with UDFRegistration with Serializable { self => @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) + + @transient + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry + @transient protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true) + new Analyzer(catalog, functionRegistry, caseSensitive = true) @transient protected[sql] val optimizer = Optimizer @transient @@ -379,7 +384,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan - lazy val analyzed = analyzer(logical) + lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... lazy val sparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala new file mode 100644 index 0000000000000..0b48e9e659faa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.execution.PythonUDF + +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +/** + * Functions for registering scala lambda functions as UDFs in a SQLContext. + */ +protected[sql] trait UDFRegistration { + self: SQLContext => + + private[spark] def registerPython( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + stringDataType: String): Unit = { + log.debug( + s""" + | Registering new PythonUDF: + | name: $name + | command: ${command.toSeq} + | envVars: $envVars + | pythonIncludes: $pythonIncludes + | pythonExec: $pythonExec + | dataType: $stringDataType + """.stripMargin) + + + val dataType = parseDataType(stringDataType) + + def builder(e: Seq[Expression]) = + PythonUDF( + name, + command, + envVars, + pythonIncludes, + pythonExec, + accumulator, + dataType, + e) + + functionRegistry.registerFunction(name, builder) + } + + /** registerFunction 1-22 were generated by this script + + (1 to 22).map { x => + val types = (1 to x).map(x => "_").reduce(_ + ", " + _) + s""" + def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { + def builder(e: Seq[Expression]) = + ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + """ + } + */ + + // scalastyle:off + def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 809dd038f94aa..ae45193ed15d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -28,14 +28,13 @@ import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} -import org.apache.spark.sql.types.util.DataTypeConversions -import DataTypeConversions.asScalaDataType; +import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils /** * The entry point for executing Spark SQL queries from a Java program. */ -class JavaSQLContext(val sqlContext: SQLContext) { +class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala new file mode 100644 index 0000000000000..158f26e3d445f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala @@ -0,0 +1,252 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.types.util.DataTypeConversions._ + +/** + * A collection of functions that allow Java users to register UDFs. In order to handle functions + * of varying airities with minimal boilerplate for our users, we generate classes and functions + * for each airity up to 22. The code for this generation can be found in comments in this trait. + */ +private[java] trait UDFRegistration { + self: JavaSQLContext => + + /* The following functions and required interfaces are generated with these code fragments: + + (1 to 22).foreach { i => + val extTypeArgs = (1 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + println(s""" + |def registerFunction( + | name: String, f: UDF$i[$extTypeArgs, _], @transient dataType: DataType) = { + | val scalaType = asScalaDataType(dataType) + | sqlContext.functionRegistry.registerFunction( + | name, + | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), scalaType, e)) + |} + """.stripMargin) + } + + import java.io.File + import org.apache.spark.sql.catalyst.util.stringToFile + val directory = new File("sql/core/src/main/java/org/apache/spark/sql/api/java/") + (1 to 22).foreach { i => + val typeArgs = (1 to i).map(i => s"T$i").mkString(", ") + val args = (1 to i).map(i => s"T$i t$i").mkString(", ") + + val contents = + s"""/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.sql.api.java; + | + |import java.io.Serializable; + | + |// ************************************************** + |// THIS FILE IS AUTOGENERATED BY CODE IN + |// org.apache.spark.sql.api.java.FunctionRegistration + |// ************************************************** + | + |/** + | * A Spark SQL UDF that has $i arguments. + | */ + |public interface UDF$i<$typeArgs, R> extends Serializable { + | public R call($args) throws Exception; + |} + |""".stripMargin + + stringToFile(new File(directory, s"UDF$i.java"), contents) + } + + */ + + // scalastyle:off + def registerFunction(name: String, f: UDF1[_, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF2[_, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF3[_, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF4[_, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF5[_, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF6[_, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF7[_, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8bec015c7b465..f0c958fdb537f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -286,6 +286,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + case e @ EvaluatePython(udf, child) => + BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case SparkLogicalPlan(existingPlan) => existingPlan :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala new file mode 100644 index 0000000000000..b92091b560b1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -0,0 +1,177 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution + +import java.util.{List => JList, Map => JMap} + +import net.razorvine.pickle.{Pickler, Unpickler} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.{Accumulator, Logging => SparkLogging} + +import scala.collection.JavaConversions._ + +/** + * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. + */ +private[spark] case class PythonUDF( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with SparkLogging { + + override def toString = s"PythonUDF#$name(${children.mkString(",")})" + + def nullable: Boolean = true + def references: Set[Attribute] = children.flatMap(_.references).toSet + + override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") +} + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan) = plan transform { + // Skip EvaluatePython nodes. + case p: EvaluatePython => p + + case l: LogicalPlan => + // Extract any PythonUDFs from the current operator. + val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf}) + if (udfs.isEmpty) { + // If there aren't any, we are done. + l + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + val udf = udfs.head + + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = l.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartisian product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + l.output, + l.transformExpressions { + case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute + }.withNewChildren(newChildren)) + } + } +} + +/** + * :: DeveloperApi :: + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. + */ +@DeveloperApi +case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { + val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() + + def references = Set.empty + def output = child.output :+ resultAttribute +} + +/** + * :: DeveloperApi :: + * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. The input + * data is cached and zipped with the result of the udf evaluation. + */ +@DeveloperApi +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + def children = child :: Nil + + def execute() = { + // TODO: Clean up after ourselves? + val childResults = child.execute().map(_.copy()).cache() + + val parent = childResults.mapPartitions { iter => + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + iter.grouped(1000).map { inputRows => + val toBePickled = inputRows.map(currentRow(_).toArray).toArray + pickle.dumps(toBePickled) + } + } + + val pyRDD = new PythonRDD( + parent, + udf.command, + udf.envVars, + udf.pythonIncludes, + false, + udf.pythonExec, + Seq[Broadcast[Array[Byte]]](), + udf.accumulator + ).mapPartitions { iter => + val pickle = new Unpickler + iter.flatMap { pickedResult => + val unpickledBatch = pickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] + } + }.mapPartitions { iter => + val row = new GenericMutableRow(1) + iter.map { result => + row(0) = udf.dataType match { + case StringType => result.toString + case other => result + } + row: Row + } + } + + childResults.zip(pyRDD).mapPartitions { iter => + val joinedRow = new JoinedRow() + iter.map { + case (row, udfResult) => + joinedRow(row, udfResult) + } + } + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java new file mode 100644 index 0000000000000..a9a11285def54 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +import org.apache.spark.sql.api.java.UDF1; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Suite; +import org.junit.runner.RunWith; + +import org.apache.spark.api.java.JavaSparkContext; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaSparkContext sc; + private transient JavaSQLContext sqlContext; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + sqlContext = new JavaSQLContext(sc); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF1() { + @Override + public Integer call(String str) throws Exception { + return str.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test')").first(); + assert(result.getInt(0) == 4); + } + + @SuppressWarnings("unchecked") + @Test + public void udf2Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", + // (String str1, String str2) -> str1.length() + str2.length, + // DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF2() { + @Override + public Integer call(String str1, String str2) throws Exception { + return str1.length() + str2.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + assert(result.getInt(0) == 9); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala index 4f0b85f26254b..23a711d08c58b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.File +import _root_.java.io.File /* Implicits */ import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala new file mode 100644 index 0000000000000..76aa9b0081d7e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class UDFSuite extends QueryTest { + + test("Simple UDF") { + registerFunction("strLenScala", (_: String).length) + assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + } + + test("TwoArgument UDF") { + registerFunction("strLenScala", (_: String).length + (_:Int)) + assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2c7270d9f83a9..3c70b3f0921a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -23,7 +23,7 @@ import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver @@ -35,8 +35,9 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand @@ -155,10 +156,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + // Note that HiveUDFs will be overridden by functions registered in this context. + override protected[sql] lazy val functionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry + /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = - new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) /** * Runs the specified SQL query using Hive. @@ -250,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. override lazy val optimizedPlan = - optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) + optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))) override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 728452a25a00e..c605e8adcfb0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -297,8 +297,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. - org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger => - logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } // It is important that we RESET first as broken hooks that might have been set could break diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d181921269b56..179aac5cbd5cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { +private[hive] abstract class HiveFunctionRegistry + extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) @@ -92,9 +93,8 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu } private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) - extends HiveUdf { + extends HiveUdf with HiveInspectors { - import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index 11d8b1f0a3d96..95921c3d7ae09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -51,9 +51,9 @@ class QueryTest extends FunSuite { fail( s""" |Exception thrown while executing query: - |${rdd.logicalPlan} + |${rdd.queryExecution} |== Exception == - |$e + |${stackTraceToString(e)} """.stripMargin) } From 4230df4e1d6c59dc3405f46f5edf18c3825a5447 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 16:48:07 -0700 Subject: [PATCH 007/231] [SPARK-2785][SQL] Remove assertions that throw when users try unsupported Hive commands. Author: Michael Armbrust Closes #1742 from marmbrus/asserts and squashes the following commits: 5182d54 [Michael Armbrust] Remove assertions that throw when users try unsupported Hive commands. (cherry picked from commit 198df11f1a9f419f820f47eba0e9f2ab371a824b) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3d2eb1eefaeda..bc2fefafd58c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -297,8 +297,11 @@ private[hive] object HiveQl { matches.headOption } - assert(remainingNodes.isEmpty, - s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}") + if (remainingNodes.nonEmpty) { + sys.error( + s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. + |You are likely trying to use an unsupported Hive feature."""".stripMargin) + } clauses } @@ -748,7 +751,10 @@ private[hive] object HiveQl { case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => - assert(other.size <= 1, s"Unhandled join child $other") + if (!(other.size <= 1)) { + sys.error(s"Unsupported join operation: $other") + } + val joinType = joinToken match { case "TOK_JOIN" => Inner case "TOK_RIGHTOUTERJOIN" => RightOuter @@ -756,7 +762,6 @@ private[hive] object HiveQl { case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi } - assert(other.size <= 1, "Unhandled join clauses.") Join(nodeToRelation(relation1), nodeToRelation(relation2), joinType, From 460fad817da1fb6619d2456f637c1b7c7f5e8c7c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 2 Aug 2014 17:12:49 -0700 Subject: [PATCH 008/231] [SPARK-2729][SQL] Added test case for SPARK-2729 This is a follow up of #1636. Author: Cheng Lian Closes #1738 from liancheng/test-for-spark-2729 and squashes the following commits: b13692a [Cheng Lian] Added test case for SPARK-2729 (cherry picked from commit 866cf1f822cfda22294054be026ef2d96307eb75) Signed-off-by: Michael Armbrust --- .../test/scala/org/apache/spark/sql/TestData.scala | 12 ++++++++++-- .../sql/columnar/InMemoryColumnarQuerySuite.scala | 12 ++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 58cee21e8ad4c..088e6e3c843aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) @@ -40,7 +42,7 @@ object TestData { LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil) largeAndSmallInts.registerAsTable("largeAndSmallInts") - + case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = TestSQLContext.sparkContext.parallelize( @@ -143,4 +145,10 @@ object TestData { "2, B2, false, null" :: "3, C3, true, null" :: "4, D4, true, 2147483644" :: Nil) + + case class TimestampField(time: Timestamp) + val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => + TimestampField(new Timestamp(i)) + }) + timestamps.registerAsTable("timestamps") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 86727b93f3659..b561b44ad7ee2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -73,4 +73,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq) } + + test("SPARK-2729 regression: timestamp data type") { + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + + TestSQLContext.cacheTable("timestamps") + + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + } } From 5ef828273deb4713a49700c56d51bdd980917cfd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 2 Aug 2014 17:55:22 -0700 Subject: [PATCH 009/231] [SPARK-2797] [SQL] SchemaRDDs don't support unpersist() The cause is explained in https://issues.apache.org/jira/browse/SPARK-2797. Author: Yin Huai Closes #1745 from yhuai/SPARK-2797 and squashes the following commits: 7b1627d [Yin Huai] The unpersist method of the Scala RDD cannot be called without the input parameter (blocking) from PySpark. (cherry picked from commit d210022e96804e59e42ab902e53637e50884a9ab) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e7c35ac1ffe02..36e50e49c9a9c 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1589,9 +1589,9 @@ def persist(self, storageLevel): self._jschema_rdd.persist(javaStorageLevel) return self - def unpersist(self): + def unpersist(self, blocking=True): self.is_cached = False - self._jschema_rdd.unpersist() + self._jschema_rdd.unpersist(blocking) return self def checkpoint(self): From 5b30e001839a29e6c4bd1fc24bfa12d9166ef10c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 18:27:04 -0700 Subject: [PATCH 010/231] [SPARK-2739][SQL] Rename registerAsTable to registerTempTable There have been user complaints that the difference between `registerAsTable` and `saveAsTable` is too subtle. This PR addresses this by renaming `registerAsTable` to `registerTempTable`, which more clearly reflects what is happening. `registerAsTable` remains, but will cause a deprecation warning. Author: Michael Armbrust Closes #1743 from marmbrus/registerTempTable and squashes the following commits: d031348 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into registerTempTable 4dff086 [Michael Armbrust] Fix .java files too 89a2f12 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into registerTempTable 0b7b71e [Michael Armbrust] Rename registerAsTable to registerTempTable (cherry picked from commit 1a8043739dc1d9435def6ea3c6341498ba52b708) Signed-off-by: Michael Armbrust --- .../sbt_app_sql/src/main/scala/SqlApp.scala | 2 +- docs/sql-programming-guide.md | 18 ++++++------ .../spark/examples/sql/JavaSparkSQL.java | 8 +++--- .../spark/examples/sql/RDDRelation.scala | 4 +-- .../examples/sql/hive/HiveFromSpark.scala | 2 +- python/pyspark/sql.py | 12 +++++--- .../org/apache/spark/sql/SQLContext.scala | 4 +-- .../org/apache/spark/sql/SchemaRDD.scala | 2 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 5 +++- .../spark/sql/api/java/JavaSQLContext.scala | 2 +- .../sql/api/java/JavaApplySchemaSuite.java | 6 ++-- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../apache/spark/sql/InsertIntoSuite.scala | 4 +-- .../org/apache/spark/sql/JoinSuite.scala | 4 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 6 ++-- .../sql/ScalaReflectionRelationSuite.scala | 8 +++--- .../scala/org/apache/spark/sql/TestData.scala | 28 +++++++++---------- .../spark/sql/api/java/JavaSQLSuite.scala | 10 +++---- .../org/apache/spark/sql/json/JsonSuite.scala | 22 +++++++-------- .../spark/sql/parquet/ParquetQuerySuite.scala | 26 ++++++++--------- .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../sql/hive/api/java/JavaHiveQLSuite.scala | 4 +-- .../sql/hive/execution/HiveQuerySuite.scala | 6 ++-- .../hive/execution/HiveResolutionSuite.scala | 4 +-- .../spark/sql/parquet/HiveParquetSuite.scala | 8 +++--- 25 files changed, 103 insertions(+), 96 deletions(-) diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index 50af90c213b5a..d888de929fdda 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -38,7 +38,7 @@ object SparkSqlExample { import sqlContext._ val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)) - people.registerAsTable("people") + people.registerTempTable("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() teenagerNames.foreach(println) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7261badd411a9..0465468084cee 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -142,7 +142,7 @@ case class Person(name: String, age: Int) // Create an RDD of Person objects and register it as a table. val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -210,7 +210,7 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m // Apply a schema to an RDD of JavaBeans and register it as a table. JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); -schemaPeople.registerAsTable("people"); +schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -248,7 +248,7 @@ people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) # In future versions of PySpark we would like to add support for registering RDDs with other # datatypes as tables schemaPeople = sqlContext.inferSchema(people) -schemaPeople.registerAsTable("people") +schemaPeople.registerTempTable("people") # SQL can be run over SchemaRDDs that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -292,7 +292,7 @@ people.saveAsParquetFile("people.parquet") val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile") +parquetFile.registerTempTable("parquetFile") val teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -314,7 +314,7 @@ schemaPeople.saveAsParquetFile("people.parquet"); JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.map(new Function() { public String call(Row row) { @@ -340,7 +340,7 @@ schemaPeople.saveAsParquetFile("people.parquet") parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): @@ -378,7 +378,7 @@ people.printSchema() // |-- name: StringType // Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -416,7 +416,7 @@ people.printSchema(); // |-- name: StringType // Register this JavaSchemaRDD as a table. -people.registerAsTable("people"); +people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -455,7 +455,7 @@ people.printSchema() # |-- name: StringType # Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") # SQL statements can be run by using the sql methods provided by sqlContext. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 607df3eddd550..898297dc658ba 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -74,7 +74,7 @@ public Person call(String line) throws Exception { // Apply a schema to an RDD of Java Beans and register it as a table. JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); - schemaPeople.registerAsTable("people"); + schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -100,7 +100,7 @@ public String call(Row row) { JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. - parquetFile.registerAsTable("parquetFile"); + parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.map(new Function() { @@ -128,7 +128,7 @@ public String call(Row row) { // |-- name: StringType // Register this JavaSchemaRDD as a table. - peopleFromJsonFile.registerAsTable("people"); + peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -158,7 +158,7 @@ public String call(Row row) { // | |-- state: StringType // |-- name: StringType - peopleFromJsonRDD.registerAsTable("people2"); + peopleFromJsonRDD.registerTempTable("people2"); JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.map(new Function() { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 63db688bfb8c0..d56d64c564200 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -36,7 +36,7 @@ object RDDRelation { val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") @@ -66,7 +66,7 @@ object RDDRelation { parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) // These files can also be registered as tables. - parquetFile.registerAsTable("parquetFile") + parquetFile.registerTempTable("parquetFile") sql("SELECT * FROM parquetFile").collect().foreach(println) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index dc5290fb4f10e..12530c8490b09 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -56,7 +56,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 36e50e49c9a9c..42b738e112809 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -909,7 +909,7 @@ def __init__(self, sparkContext, sqlContext=None): ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerAsTable("allTypes") + >>> srdd.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] @@ -1486,19 +1486,23 @@ def saveAsParquetFile(self, path): """ self._jschema_rdd.saveAsParquetFile(path) - def registerAsTable(self, name): + def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} that was used to create this SchemaRDD. >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerAsTable("test") + >>> srdd.registerTempTable("test") >>> srdd2 = sqlCtx.sql("select * from test") >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - self._jschema_rdd.registerAsTable(name) + self._jschema_rdd.registerTempTable(name) + + def registerAsTable(self, name): + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this SchemaRDD into the specified table. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 33931e5d996f5..567f4dca991b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -116,7 +116,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * peopleSchemaRDD.registerAsTable("people") + * peopleSchemaRDD.registerTempTable("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} * @@ -212,7 +212,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * import sqlContext._ * * case class Person(name: String, age: Int) - * createParquetFile[Person]("path/to/file.parquet").registerAsTable("people") + * createParquetFile[Person]("path/to/file.parquet").registerTempTable("people") * sql("INSERT INTO people SELECT 'michael', 29") * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index d34f62dc8865e..57df79321b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -67,7 +67,7 @@ import org.apache.spark.api.java.JavaRDD * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) * // Any RDD containing case classes can be registered as a table. The schema of the table is * // automatically inferred using scala reflection. - * rdd.registerAsTable("records") + * rdd.registerTempTable("records") * * val results: SchemaRDD = sql("SELECT * FROM records") * }}} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 6a20def475822..2f3033a5f94f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -83,10 +83,13 @@ private[sql] trait SchemaRDDLike { * * @group schema */ - def registerAsTable(tableName: String): Unit = { + def registerTempTable(tableName: String): Unit = { sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) } + @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") + def registerAsTable(tableName: String): Unit = registerTempTable(tableName) + /** * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index ae45193ed15d3..dbaa16e8b0c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -52,7 +52,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * {{{ * JavaSQLContext sqlCtx = new JavaSQLContext(...) * - * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerAsTable("people") + * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerTempTable("people") * sqlCtx.sql("INSERT INTO people SELECT 'michael', 29") * }}} * diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 3c92906d82864..33e5020bc636a 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -98,7 +98,7 @@ public Row call(Person person) throws Exception { StructType schema = DataType.createStructType(fields); JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); - schemaRDD.registerAsTable("people"); + schemaRDD.registerTempTable("people"); List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); @@ -149,14 +149,14 @@ public void applySchemaToJSON() { JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD); StructType actualSchema1 = schemaRDD1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerAsTable("jsonTable1"); + schemaRDD1.registerTempTable("jsonTable1"); List actual1 = javaSqlCtx.sql("select * from jsonTable1").collect(); Assert.assertEquals(expectedResult, actual1); JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); StructType actualSchema2 = schemaRDD2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD1.registerAsTable("jsonTable2"); + schemaRDD1.registerTempTable("jsonTable2"); List actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c3c0dcb1aa00b..fbf9bd9dbcdea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -78,7 +78,7 @@ class CachedTableSuite extends QueryTest { } test("SELECT Star Cached Table") { - TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar") + TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar") TestSQLContext.cacheTable("selectStar") TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() TestSQLContext.uncacheTable("selectStar") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala index 23a711d08c58b..c87d762751e6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala @@ -31,7 +31,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertTest") + testFile.registerTempTable("createAndInsertTest") // Add some data. testData.insertInto("createAndInsertTest") @@ -86,7 +86,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertSQLTest") + testFile.registerTempTable("createAndInsertSQLTest") sql("INSERT INTO createAndInsertSQLTest SELECT * FROM testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2fc80588182d9..6c7697ece8c56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -285,8 +285,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("full outer join") { - upperCaseData.where('N <= 4).registerAsTable("left") - upperCaseData.where('N >= 3).registerAsTable("right") + upperCaseData.where('N <= 4).registerTempTable("left") + upperCaseData.where('N >= 3).registerTempTable("right") val left = UnresolvedRelation(None, "left", None) val right = UnresolvedRelation(None, "right", None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5c571d35d1bb9..9b2a36d33fca7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -461,7 +461,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerAsTable("applySchema1") + schemaRDD1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), (1, "A1", true, null) :: @@ -491,7 +491,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD2 = applySchema(rowRDD2, schema2) - schemaRDD2.registerAsTable("applySchema2") + schemaRDD2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), (Seq(1, true), Map("A1" -> null)) :: @@ -516,7 +516,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD3 = applySchema(rowRDD3, schema2) - schemaRDD3.registerAsTable("applySchema3") + schemaRDD3.registerTempTable("applySchema3") checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index f2934da9a031d..5b84c658db942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -61,7 +61,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectData") + rdd.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) } @@ -69,7 +69,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectNullData") + rdd.registerTempTable("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null)) } @@ -77,7 +77,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectOptionalData") + rdd.registerTempTable("reflectOptionalData") assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null)) } @@ -85,7 +85,7 @@ class ScalaReflectionRelationSuite extends FunSuite { // Equality is broken for Arrays, so we test that separately. test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerAsTable("reflectBinary") + rdd.registerTempTable("reflectBinary") val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 088e6e3c843aa..c3ec82fb69778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -30,7 +30,7 @@ case class TestData(key: Int, value: String) object TestData { val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") case class LargeAndSmallInts(a: Int, b: Int) val largeAndSmallInts: SchemaRDD = @@ -41,7 +41,7 @@ object TestData { LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil) - largeAndSmallInts.registerAsTable("largeAndSmallInts") + largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = @@ -52,7 +52,7 @@ object TestData { TestData2(2, 2) :: TestData2(3, 1) :: TestData2(3, 2) :: Nil) - testData2.registerAsTable("testData2") + testData2.registerTempTable("testData2") // TODO: There is no way to express null primitives as case classes currently... val testData3 = @@ -71,7 +71,7 @@ object TestData { UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: UpperCaseData(6, "F") :: Nil) - upperCaseData.registerAsTable("upperCaseData") + upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) val lowerCaseData = @@ -80,14 +80,14 @@ object TestData { LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: LowerCaseData(4, "d") :: Nil) - lowerCaseData.registerAsTable("lowerCaseData") + lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerAsTable("arrayData") + arrayData.registerTempTable("arrayData") case class MapData(data: Map[Int, String]) val mapData = @@ -97,18 +97,18 @@ object TestData { MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - mapData.registerAsTable("mapData") + mapData.registerTempTable("mapData") case class StringData(s: String) val repeatedData = TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerAsTable("repeatedData") + repeatedData.registerTempTable("repeatedData") val nullableRepeatedData = TestSQLContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - nullableRepeatedData.registerAsTable("nullableRepeatedData") + nullableRepeatedData.registerTempTable("nullableRepeatedData") case class NullInts(a: Integer) val nullInts = @@ -118,7 +118,7 @@ object TestData { NullInts(3) :: NullInts(null) :: Nil ) - nullInts.registerAsTable("nullInts") + nullInts.registerTempTable("nullInts") val allNulls = TestSQLContext.sparkContext.parallelize( @@ -126,7 +126,7 @@ object TestData { NullInts(null) :: NullInts(null) :: NullInts(null) :: Nil) - allNulls.registerAsTable("allNulls") + allNulls.registerTempTable("allNulls") case class NullStrings(n: Int, s: String) val nullStrings = @@ -134,10 +134,10 @@ object TestData { NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil) - nullStrings.registerAsTable("nullStrings") + nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") + TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") val unparsedStrings = TestSQLContext.sparkContext.parallelize( @@ -150,5 +150,5 @@ object TestData { val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => TimestampField(new Timestamp(i)) }) - timestamps.registerAsTable("timestamps") + timestamps.registerTempTable("timestamps") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 020baf0c7ec6f..203ff847e94cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -59,7 +59,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(person :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[PersonBean]) - schemaRDD.registerAsTable("people") + schemaRDD.registerTempTable("people") javaSqlCtx.sql("SELECT * FROM people").collect() } @@ -76,7 +76,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -101,7 +101,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -127,7 +127,7 @@ class JavaSQLSuite extends FunSuite { var schemaRDD = javaSqlCtx.jsonRDD(rdd) - schemaRDD.registerAsTable("jsonTable1") + schemaRDD.registerTempTable("jsonTable1") assert( javaSqlCtx.sql("select * from jsonTable1").collect.head.row === @@ -144,7 +144,7 @@ class JavaSQLSuite extends FunSuite { rdd.saveAsTextFile(path) schemaRDD = javaSqlCtx.jsonFile(path) - schemaRDD.registerAsTable("jsonTable2") + schemaRDD.registerTempTable("jsonTable2") assert( javaSqlCtx.sql("select * from jsonTable2").collect.head.row === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 9d9cfdd7c92e3..75c0589eb208e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -183,7 +183,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -223,7 +223,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -291,7 +291,7 @@ class JsonSuite extends QueryTest { ignore("Complex field and type inferring (Ignored)") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( @@ -320,7 +320,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -374,7 +374,7 @@ class JsonSuite extends QueryTest { ignore("Type conflict in primitive field values (Ignored)") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expreesion. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -445,7 +445,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -466,7 +466,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -494,7 +494,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") } test("Loading a JSON dataset from a text file") { @@ -514,7 +514,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -546,7 +546,7 @@ class JsonSuite extends QueryTest { assert(schema === jsonSchemaRDD1.schema) - jsonSchemaRDD1.registerAsTable("jsonTable1") + jsonSchemaRDD1.registerTempTable("jsonTable1") checkAnswer( sql("select * from jsonTable1"), @@ -563,7 +563,7 @@ class JsonSuite extends QueryTest { assert(schema === jsonSchemaRDD2.schema) - jsonSchemaRDD2.registerAsTable("jsonTable2") + jsonSchemaRDD2.registerTempTable("jsonTable2") checkAnswer( sql("select * from jsonTable2"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 8955455ec98c7..9933575038bd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -101,9 +101,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA ParquetTestData.writeNestedFile3() ParquetTestData.writeNestedFile4() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) - .registerAsTable("testfiltersource") + .registerTempTable("testfiltersource") } override def afterAll() { @@ -247,7 +247,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Creating case class RDD table") { TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - .registerAsTable("tmp") + .registerTempTable("tmp") val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) var counter = 1 rdd.foreach { @@ -266,7 +266,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .map(i => TestRDDEntry(i, s"val_$i")) rdd.saveAsParquetFile(path) val readFile = parquetFile(path) - readFile.registerAsTable("tmpx") + readFile.registerTempTable("tmpx") val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { @@ -280,9 +280,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val dirname = Utils.createTempDir() val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - source_rdd.registerAsTable("source") + source_rdd.registerTempTable("source") val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) - dest_rdd.registerAsTable("dest") + dest_rdd.registerTempTable("dest") sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) @@ -547,7 +547,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir1.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() assert(tmp.size === 2) @@ -562,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir2.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -589,7 +589,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir3.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -608,7 +608,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = TestSQLContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = sql("SELECT data1 FROM mapTable").collect() assert(result1.size === 1) assert(result1(0)(0) @@ -625,7 +625,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() assert(result1.size === 1) val entry1 = result1(0)(0) @@ -658,7 +658,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpcopy") + .registerTempTable("tmpcopy") val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() assert(tmpdata.size === 2) assert(tmpdata(0).size === 2) @@ -679,7 +679,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpmapcopy") + .registerTempTable("tmpmapcopy") val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() assert(result1.size === 1) assert(result1(0)(0) === 2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 833f3502154f3..7e323146f9da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -28,7 +28,7 @@ case class TestData(key: Int, value: String) class InsertIntoHiveTableSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") test("insertInto() HiveTable") { createTable[TestData]("createAndInsertTest") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 10c8069a624e6..578f27574ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -63,7 +63,7 @@ class JavaHiveQLSuite extends FunSuite { javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.hql("SHOW TABLES").registerAsTable("show_tables") + javaHiveCtx.hql("SHOW TABLES").registerTempTable("show_tables") assert( javaHiveCtx @@ -73,7 +73,7 @@ class JavaHiveQLSuite extends FunSuite { .contains(tableName)) assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - javaHiveCtx.hql(s"DESCRIBE $tableName").registerAsTable("describe_table") + javaHiveCtx.hql(s"DESCRIBE $tableName").registerTempTable("describe_table") javaHiveCtx .hql("SELECT result FROM describe_table") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 89cc589fb8001..4ed41550cf530 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -247,7 +247,7 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerAsTable("REGisteredTABle") + testData.registerTempTable("REGisteredTABle") assertResult(Array(Array(2, "str2"))) { hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -272,7 +272,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") val results = hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -401,7 +401,7 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerAsTable("test_describe_commands2") + testData.registerTempTable("test_describe_commands2") assertResult( Array( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index fb03db12a0b01..2455c18925dfa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -54,14 +54,14 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("caseSensitivityTest") + .registerTempTable("caseSensitivityTest") hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } test("nested repeated resolution") { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("nestedRepeatedTest") + .registerTempTable("nestedRepeatedTest") assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 47526e3596e44..6545e8d7dcb69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -41,7 +41,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft // write test data ParquetTestData.writeFile() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") } override def afterAll() { @@ -67,7 +67,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft .map(i => Cases(i, i)) .saveAsParquetFile(tempFile.getCanonicalPath) - parquetFile(tempFile.getCanonicalPath).registerAsTable("cases") + parquetFile(tempFile.getCanonicalPath).registerTempTable("cases") hql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) hql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) } @@ -86,7 +86,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("Converting Hive to Parquet Table via saveAsParquetFile") { hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0)) val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0)) compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) @@ -94,7 +94,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("INSERT OVERWRITE TABLE Parquet table") { hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") // let's do three overwrites for good measure hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() From 0d47bb642f645c3c8663f4bdf869b5337ef9cb35 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Aug 2014 21:44:19 -0700 Subject: [PATCH 011/231] SPARK-2602 [BUILD] Tests steal focus under Java 6 As per https://issues.apache.org/jira/browse/SPARK-2602 , this may be resolved for Java 6 with the java.awt.headless system property, which never hurt anyone running a command line app. I tested it and seemed to get rid of focus stealing. Author: Sean Owen Closes #1747 from srowen/SPARK-2602 and squashes the following commits: b141018 [Sean Owen] Set java.awt.headless during tests (cherry picked from commit 33f167d762483b55d5d874dcc1e3075f661d4375) Signed-off-by: Patrick Wendell --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index a42759169149b..cc9377cec2a07 100644 --- a/pom.xml +++ b/pom.xml @@ -871,6 +871,7 @@ -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + true ${session.executionRootDirectory} 1 From c137928cbe74446254fdbd656c50c1a1c8930094 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Aug 2014 21:55:56 -0700 Subject: [PATCH 012/231] SPARK-2414 [BUILD] Add LICENSE entry for jquery The JIRA concerned removing jquery, and this does not remove jquery. While it is distributed by Spark it should have an accompanying line in LICENSE, very technically, as per http://www.apache.org/dev/licensing-howto.html Author: Sean Owen Closes #1748 from srowen/SPARK-2414 and squashes the following commits: 2fdb03c [Sean Owen] Add LICENSE entry for jquery (cherry picked from commit 9cf429aaf529e91f619910c33cfe46bf33a66982) Signed-off-by: Patrick Wendell --- LICENSE | 1 + 1 file changed, 1 insertion(+) diff --git a/LICENSE b/LICENSE index 76a3601c66918..e9a1153fdc5db 100644 --- a/LICENSE +++ b/LICENSE @@ -549,3 +549,4 @@ The following components are provided under the MIT License. See project link fo (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (MIT License) jquery (https://jquery.org/license/) From fb2a2079fa10ea8f338d68945a94238dda9fbd66 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 2 Aug 2014 22:00:46 -0700 Subject: [PATCH 013/231] [Minor] Fixes on top of #1679 Minor fixes on top of #1679. Author: Andrew Or Closes #1736 from andrewor14/amend-#1679 and squashes the following commits: 3b46f5e [Andrew Or] Minor fixes (cherry picked from commit 3dc55fdf450b4237f7c592fce56d1467fd206366) Signed-off-by: Patrick Wendell --- .../org/apache/spark/storage/BlockManagerSource.scala | 5 ++--- .../scala/org/apache/spark/storage/StorageUtils.scala | 11 ++++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index e939318a029dd..3f14c40ec61cb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -46,9 +46,8 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).sum - val remainingMem = storageStatusList.map(_.memRemaining).sum - (maxMem - remainingMem) / 1024 / 1024 + val memUsed = storageStatusList.map(_.memUsed).sum + memUsed / 1024 / 1024 } }) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 0a0a448baa2ef..2bd6b749be261 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -172,16 +172,13 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { def memRemaining: Long = maxMem - memUsed /** Return the memory used by this block manager. */ - def memUsed: Long = - _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum /** Return the disk space used by this block manager. */ - def diskUsed: Long = - _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum /** Return the off-heap space used by this block manager. */ - def offHeapUsed: Long = - _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum + def offHeapUsed: Long = _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum /** Return the memory used by the given RDD in this block manager in O(1) time. */ def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) @@ -246,7 +243,7 @@ private[spark] object StorageUtils { val rddId = rddInfo.id // Assume all blocks belonging to the same RDD have the same storage level val storageLevel = statuses - .map(_.rddStorageLevel(rddId)).flatMap(s => s).headOption.getOrElse(StorageLevel.NONE) + .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum val memSize = statuses.map(_.memUsedByRdd(rddId)).sum val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum From 1992175fd93f0239e5a09e0b8db99ad9af7f380c Mon Sep 17 00:00:00 2001 From: Stephen Boesch Date: Sun, 3 Aug 2014 10:19:04 -0700 Subject: [PATCH 014/231] SPARK-2712 - Add a small note to maven doc that mvn package must happen ... Per request by Reynold adding small note about proper sequencing of build then test. Author: Stephen Boesch Closes #1615 from javadba/docs and squashes the following commits: 6c3183e [Stephen Boesch] Moved updated testing blurb per PWendell 5764757 [Stephen Boesch] SPARK-2712 - Add a small note to maven doc that mvn package must happen before test (cherry picked from commit f8cd143b6b1b4d8aac87c229e5af263b0319b3ea) Signed-off-by: Patrick Wendell --- docs/building-with-maven.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 55a9e37dfed83..672d0ef114f6d 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -98,7 +98,12 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski # Spark Tests in Maven -Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). Some of the require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. You can then run the tests with `mvn -Dhadoop.version=... test`. +Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). + +Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: + + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package + mvn -Pyarn -Phadoop-2.3 -Phive test The ScalaTest plugin also supports running only a specific test suite as follows: From 162fc9512018e0c592b3aaa29d405f511461795a Mon Sep 17 00:00:00 2001 From: "Allan Douglas R. de Oliveira" Date: Sun, 3 Aug 2014 10:25:59 -0700 Subject: [PATCH 015/231] SPARK-2246: Add user-data option to EC2 scripts Author: Allan Douglas R. de Oliveira Closes #1186 from douglaz/spark_ec2_user_data and squashes the following commits: 94a36f9 [Allan Douglas R. de Oliveira] Added user data option to EC2 script (cherry picked from commit a0bcbc159e89be868ccc96175dbf1439461557e1) Signed-off-by: Patrick Wendell --- ec2/spark_ec2.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 02cfe4ec39c7d..0c2f85a3868f4 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -135,6 +135,10 @@ def parse_args(): "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + "(e.g -Dspark.worker.timeout=180)") + parser.add_option( + "--user-data", type="string", default="", + help="Path to a user-data file (most AMI's interpret this as an initialization script)") + (opts, args) = parser.parse_args() if len(args) != 2: @@ -274,6 +278,12 @@ def launch_cluster(conn, opts, cluster_name): if opts.key_pair is None: print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." sys.exit(1) + + user_data_content = None + if opts.user_data: + with open(opts.user_data) as user_data_file: + user_data_content = user_data_file.read() + print "Setting up security groups..." master_group = get_or_make_group(conn, cluster_name + "-master") slave_group = get_or_make_group(conn, cluster_name + "-slaves") @@ -347,7 +357,8 @@ def launch_cluster(conn, opts, cluster_name): key_name=opts.key_pair, security_groups=[slave_group], instance_type=opts.instance_type, - block_device_map=block_map) + block_device_map=block_map, + user_data=user_data_content) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -398,7 +409,8 @@ def launch_cluster(conn, opts, cluster_name): placement=zone, min_count=num_slaves_this_zone, max_count=num_slaves_this_zone, - block_device_map=block_map) + block_device_map=block_map, + user_data=user_data_content) slave_nodes += slave_res.instances print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, zone, slave_res.id) From eaa93555a7f935b00a2f94a7fa50a12e11578bd7 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 3 Aug 2014 10:36:52 -0700 Subject: [PATCH 016/231] [SPARK-2197] [mllib] Java DecisionTree bug fix and easy-of-use Bug fix: Before, when an RDD was created in Java and passed to DecisionTree.train(), the fake class tag caused problems. * Fix: DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java. Other improvements to Decision Trees for easy-of-use with Java: * impurity classes: Added instance() methods to help with Java interface. * Strategy: Added Java-friendly constructor --> Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included. CC: mengxr Author: Joseph K. Bradley Closes #1740 from jkbradley/dt-java-new and squashes the following commits: 0805dc6 [Joseph K. Bradley] Changed Strategy to use JavaConverters instead of JavaConversions 519b1b7 [Joseph K. Bradley] * Organized imports in JavaDecisionTreeSuite.java * Using JavaConverters instead of JavaConversions in DecisionTreeSuite.scala f7b5ca1 [Joseph K. Bradley] Improvements to make it easier to run DecisionTree from Java. * DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java. * impurity classes: Added instance() methods to help with Java interface. * Strategy: Added Java-friendly constructor ** Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included. d78ada6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java 320853f [Joseph K. Bradley] Added JavaDecisionTreeSuite, partly written 13a585e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java f1a8283 [Joseph K. Bradley] Added old JavaDecisionTreeSuite, to be updated later 225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. (cherry picked from commit 2998e38a942351974da36cb619e863c6f0316e7a) Signed-off-by: Xiangrui Meng --- .../spark/mllib/tree/DecisionTree.scala | 8 +- .../mllib/tree/configuration/Strategy.scala | 29 +++++ .../spark/mllib/tree/impurity/Entropy.scala | 7 ++ .../spark/mllib/tree/impurity/Gini.scala | 7 ++ .../spark/mllib/tree/impurity/Variance.scala | 7 ++ .../mllib/tree/JavaDecisionTreeSuite.java | 102 ++++++++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 6 ++ 7 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 382e76a9b7cba..1d03e6e3b36cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -48,12 +48,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo def train(input: RDD[LabeledPoint]): DecisionTreeModel = { // Cache input RDD for speedup during multiple passes. - input.cache() + val retaggedInput = input.retag(classOf[LabeledPoint]).cache() logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = input.take(1)(0).features.size + val numFeatures = retaggedInput.take(1)(0).features.size // Calculate level for single group construction @@ -107,7 +107,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, + val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index fdad4f029aa99..4ee4bcd0bcbc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.configuration +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.configuration.Algo._ @@ -61,4 +63,31 @@ class Strategy ( val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + /** + * Java-friendly constructor. + * + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param numClassesForClassification number of classes for classification. Default value is 2 + * leads to binary classification + * @param maxBins maximum number of bins used for splitting features + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, an entry + * (n -> k) implies the feature n is categorical with k categories + * 0, 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int, + maxBins: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { + this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 9297c20596527..96d2471e1f88c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -66,4 +66,11 @@ object Entropy extends Impurity { @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 2874bcf496484..d586f449048bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -62,4 +62,11 @@ object Gini extends Impurity { @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 698a1a2a8e899..f7d99a40eb380 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -53,4 +53,11 @@ object Variance extends Impurity { val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java new file mode 100644 index 0000000000000..2c281a1ee7157 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.configuration.Algo; +import org.apache.spark.mllib.tree.configuration.Strategy; +import org.apache.spark.mllib.tree.impurity.Gini; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; + + +public class JavaDecisionTreeSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + int validatePrediction(List validationData, DecisionTreeModel model) { + int numCorrect = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numCorrect++; + } + } + return numCorrect; + } + + @Test + public void runDTUsingConstructor() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTree learner = new DecisionTree(strategy); + DecisionTreeModel model = learner.train(rdd.rdd()); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + + @Test + public void runDTUsingStaticMethods() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8665a00f3b356..70ca7c8a266f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConverters._ + import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} @@ -815,6 +817,10 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = { + generateCategoricalDataPoints().toList.asJava + } + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { From c5ed1deba6b3f3e597554a8d0f93f402ae62fab9 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 3 Aug 2014 12:28:29 -0700 Subject: [PATCH 017/231] [SPARK-2784][SQL] Deprecate hql() method in favor of a config option, 'spark.sql.dialect' Many users have reported being confused by the distinction between the `sql` and `hql` methods. Specifically, many users think that `sql(...)` cannot be used to read hive tables. In this PR I introduce a new configuration option `spark.sql.dialect` that picks which dialect with be used for parsing. For SQLContext this must be set to `sql`. In `HiveContext` it defaults to `hiveql` but can also be set to `sql`. The `hql` and `hiveql` methods continue to act the same but are now marked as deprecated. **This is a possibly breaking change for some users unless they set the dialect manually, though this is unlikely.** For example: `hiveContex.sql("SELECT 1")` will now throw a parsing exception by default. Author: Michael Armbrust Closes #1746 from marmbrus/sqlLanguageConf and squashes the following commits: ad375cc [Michael Armbrust] Merge remote-tracking branch 'apache/master' into sqlLanguageConf 20c43f8 [Michael Armbrust] override function instead of just setting the value 7e4ae93 [Michael Armbrust] Deprecate hql() method in favor of a config option, 'spark.sql.dialect' (cherry picked from commit 236dfac6769016e433b2f6517cda2d308dea74bc) Signed-off-by: Michael Armbrust --- .../sbt_app_hive/src/main/scala/HiveApp.scala | 8 +- docs/sql-programming-guide.md | 18 ++-- .../examples/sql/hive/HiveFromSpark.scala | 12 +-- python/pyspark/sql.py | 20 ++-- .../scala/org/apache/spark/sql/SQLConf.scala | 17 +++- .../org/apache/spark/sql/SQLContext.scala | 11 ++- .../spark/sql/api/java/JavaSQLContext.scala | 14 ++- .../hive/thriftserver/SparkSQLDriver.scala | 2 +- .../server/SparkSQLOperationManager.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 26 ++++-- .../sql/hive/api/java/JavaHiveContext.scala | 15 ++- .../spark/sql/hive/CachedTableSuite.scala | 14 +-- .../spark/sql/hive/StatisticsSuite.scala | 10 +- .../sql/hive/api/java/JavaHiveQLSuite.scala | 19 ++-- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/HiveQuerySuite.scala | 93 ++++++++++--------- .../hive/execution/HiveResolutionSuite.scala | 6 +- .../execution/HiveTypeCoercionSuite.scala | 2 +- .../sql/hive/execution/HiveUdfSuite.scala | 10 +- .../sql/hive/execution/PruningSuite.scala | 2 +- .../spark/sql/parquet/HiveParquetSuite.scala | 27 +++--- 21 files changed, 199 insertions(+), 133 deletions(-) diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index a21410f3b9813..5111bc0adb772 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -37,10 +37,10 @@ object SparkSqlExample { val hiveContext = new HiveContext(sc) import hiveContext._ - hql("DROP TABLE IF EXISTS src") - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - hql("LOAD DATA LOCAL INPATH 'data.txt' INTO TABLE src") - val results = hql("FROM src SELECT key, value WHERE key >= 0 AND KEY < 5").collect() + sql("DROP TABLE IF EXISTS src") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("LOAD DATA LOCAL INPATH 'data.txt' INTO TABLE src") + val results = sql("FROM src SELECT key, value WHERE key >= 0 AND KEY < 5").collect() results.foreach(println) def test(f: => Boolean, failureMsg: String) = { diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0465468084cee..cd6543945c385 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -495,11 +495,11 @@ directory. // sc is an existing SparkContext. val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -hiveContext.hql("FROM src SELECT key, value").collect().foreach(println) +hiveContext.sql("FROM src SELECT key, value").collect().foreach(println) {% endhighlight %}
    @@ -515,11 +515,11 @@ expressed in HiveQL. // sc is an existing JavaSparkContext. JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL. -Row[] results = hiveContext.hql("FROM src SELECT key, value").collect(); +Row[] results = hiveContext.sql("FROM src SELECT key, value").collect(); {% endhighlight %} @@ -537,11 +537,11 @@ expressed in HiveQL. from pyspark.sql import HiveContext hiveContext = HiveContext(sc) -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = hiveContext.hql("FROM src SELECT key, value").collect() +results = hiveContext.sql("FROM src SELECT key, value").collect() {% endhighlight %} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 12530c8490b09..3423fac0ad303 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -34,20 +34,20 @@ object HiveFromSpark { val hiveContext = new HiveContext(sc) import hiveContext._ - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") - hql("SELECT * FROM src").collect.foreach(println) + sql("SELECT * FROM src").collect.foreach(println) // Aggregation queries are also supported. - val count = hql("SELECT COUNT(*) FROM src").collect().head.getLong(0) + val count = sql("SELECT COUNT(*) FROM src").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") + val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") val rddAsStrings = rddFromSql.map { @@ -60,6 +60,6 @@ object HiveFromSpark { // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") - hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) + sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) } } diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 42b738e112809..1a829c6fafe03 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1291,16 +1291,20 @@ def _get_hive_ctx(self): def hiveql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as - a L{SchemaRDD}. + DEPRECATED: Use sql() """ + warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", + DeprecationWarning) return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) def hql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as - a L{SchemaRDD}. + DEPRECATED: Use sql() """ + warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by" + + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", + DeprecationWarning) return self.hiveql(hqlQuery) @@ -1313,16 +1317,16 @@ class LocalHiveContext(HiveContext): >>> import os >>> hiveCtx = LocalHiveContext(sc) >>> try: - ... supress = hiveCtx.hql("DROP TABLE src") + ... supress = hiveCtx.sql("DROP TABLE src") ... except Exception: ... pass >>> kv1 = os.path.join(os.environ["SPARK_HOME"], ... 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.hql( + >>> supress = hiveCtx.sql( ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" + >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" ... % kv1) - >>> results = hiveCtx.hql("FROM src SELECT value" + >>> results = hiveCtx.sql("FROM src SELECT value" ... ).map(lambda r: int(r.value.split('_')[1])) >>> num = results.count() >>> reduce_sum = results.reduce(lambda x, y: x + y) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2d407077be303..40bfd55e95a12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -29,6 +29,7 @@ object SQLConf { val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" val CODEGEN_ENABLED = "spark.sql.codegen" + val DIALECT = "spark.sql.dialect" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -39,7 +40,7 @@ object SQLConf { * A trait that enables the setting and getting of mutable config parameters/hints. * * In the presence of a SQLContext, these can be set and queried by passing SET commands - * into Spark SQL's query functions (sql(), hql(), etc.). Otherwise, users of this trait can + * into Spark SQL's query functions (i.e. sql()). Otherwise, users of this trait can * modify the hints by programmatically calling the setters and getters of this trait. * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). @@ -53,6 +54,20 @@ trait SQLConf { /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? + /** + * The SQL dialect that is used when parsing queries. This defaults to 'sql' which uses + * a simple SQL parser provided by Spark SQL. This is currently the only option for users of + * SQLContext. + * + * When using a HiveContext, this value defaults to 'hiveql', which uses the Hive 0.12.0 HiveQL + * parser. Users can change this to 'sql' if they want to run queries that aren't supported by + * HiveQL (e.g., SELECT 1). + * + * Note that the choice of dialect does not affect things like what tables are available or + * how query execution is performed. + */ + private[spark] def dialect: String = get(DIALECT, "sql") + /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = get(COMPRESS_CACHED, "false").toBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 567f4dca991b2..ecd5fbaa0b094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -248,11 +248,18 @@ class SQLContext(@transient val sparkContext: SparkContext) } /** - * Executes a SQL query using Spark, returning the result as a SchemaRDD. + * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. * * @group userf */ - def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText)) + def sql(sqlText: String): SchemaRDD = { + if (dialect == "sql") { + new SchemaRDD(this, parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $dialect") + } + } /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index dbaa16e8b0c68..150ff8a42063d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -39,10 +39,18 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) /** - * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD + * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. + * + * @group userf */ - def sql(sqlQuery: String): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) + def sql(sqlText: String): JavaSchemaRDD = { + if (sqlContext.dialect == "sql") { + new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $sqlContext.dialect") + } + } /** * :: Experimental :: diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index d362d599d08ca..7463df1f47d43 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -55,7 +55,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo override def run(command: String): CommandProcessorResponse = { // TODO unify the error code try { - val execution = context.executePlan(context.hql(command).logicalPlan) + val execution = context.executePlan(context.sql(command).logicalPlan) hiveResponse = execution.stringResult() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index d4dadfd21d13f..dee092159dd4c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -128,7 +128,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) try { - result = hiveContext.hql(statement) + result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 3c70b3f0921a5..7db0159512610 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -71,15 +71,29 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => + // Change the default SQL dialect to HiveQL + override private[spark] def dialect: String = get(SQLConf.DIALECT, "hiveql") + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } - /** - * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. - */ + override def sql(sqlText: String): SchemaRDD = { + // TODO: Create a framework for registering parsers instead of just hardcoding if statements. + if (dialect == "sql") { + super.sql(sqlText) + } else if (dialect == "hiveql") { + new SchemaRDD(this, HiveQl.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") + } + } + + @deprecated("hiveql() is deprecated as the sql function now parses using HiveQL by default. " + + s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - /** An alias for `hiveql`. */ + @deprecated("hql() is deprecated as the sql function now parses using HiveQL by default. " + + s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) /** @@ -95,7 +109,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient - protected val outputBuffer = new java.io.OutputStream { + protected lazy val outputBuffer = new java.io.OutputStream { var pos: Int = 0 var buffer = new Array[Int](10240) def write(i: Int): Unit = { @@ -125,7 +139,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * SQLConf and HiveConf contracts: when the hive session is first initialized, params in * HiveConf will get picked up by the SQLConf. Additionally, any properties set by - * set() or a SET command inside hql() or sql() will be set in the SQLConf *as well as* + * set() or a SET command inside sql() will be set in the SQLConf *as well as* * in the HiveConf. */ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala index c9ee162191c96..a201d2349a2ef 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.api.java import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.api.java.{JavaSQLContext, JavaSchemaRDD} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.{HiveContext, HiveQl} /** @@ -28,9 +29,21 @@ class JavaHiveContext(sparkContext: JavaSparkContext) extends JavaSQLContext(spa override val sqlContext = new HiveContext(sparkContext) + override def sql(sqlText: String): JavaSchemaRDD = { + // TODO: Create a framework for registering parsers instead of just hardcoding if statements. + if (sqlContext.dialect == "sql") { + super.sql(sqlText) + } else if (sqlContext.dialect == "hiveql") { + new JavaSchemaRDD(sqlContext, HiveQl.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: ${sqlContext.dialect}. Try 'sql' or 'hiveql'") + } + } + /** - * Executes a query expressed in HiveQL, returning the result as a JavaSchemaRDD. + * DEPRECATED: Use sql(...) Instead */ + @Deprecated def hql(hqlQuery: String): JavaSchemaRDD = new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 08da6405a17c6..188579edd7bdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -35,17 +35,17 @@ class CachedTableSuite extends HiveComparisonTest { "SELECT * FROM src LIMIT 1", reset = false) test("Drop cached table") { - hql("CREATE TABLE test(a INT)") + sql("CREATE TABLE test(a INT)") cacheTable("test") - hql("SELECT * FROM test").collect() - hql("DROP TABLE test") + sql("SELECT * FROM test").collect() + sql("DROP TABLE test") intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { - hql("SELECT * FROM test").collect() + sql("SELECT * FROM test").collect() } } test("DROP nonexistant table") { - hql("DROP TABLE IF EXISTS nonexistantTable") + sql("DROP TABLE IF EXISTS nonexistantTable") } test("check that table is cached and uncache") { @@ -74,14 +74,14 @@ class CachedTableSuite extends HiveComparisonTest { } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.hql("CACHE TABLE src") + TestHive.sql("CACHE TABLE src") TestHive.table("src").queryExecution.executedPlan match { case _: InMemoryColumnarTableScan => // Found evidence of caching case _ => fail(s"Table 'src' should be cached") } assert(TestHive.isCached("src"), "Table 'src' should be cached") - TestHive.hql("UNCACHE TABLE src") + TestHive.sql("UNCACHE TABLE src") TestHive.table("src").queryExecution.executedPlan match { case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") case _ => // Found evidence of uncaching diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a61fd9df95c94..d8c77d6021d63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest { test("estimates the size of a test MetastoreRelation") { - val rdd = hql("""SELECT * FROM src""") + val rdd = sql("""SELECT * FROM src""") val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } @@ -45,7 +45,7 @@ class StatisticsSuite extends QueryTest { ct: ClassTag[_]) = { before() - var rdd = hql(query) + var rdd = sql(query) // Assert src has a size smaller than the threshold. val sizes = rdd.queryExecution.analyzed.collect { @@ -65,8 +65,8 @@ class StatisticsSuite extends QueryTest { TestHive.settings.synchronized { val tmp = autoBroadcastJoinThreshold - hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") - rdd = hql(query) + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + rdd = sql(query) bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") @@ -74,7 +74,7 @@ class StatisticsSuite extends QueryTest { assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") } after() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 578f27574ad2f..9644b707eb1a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -40,7 +40,7 @@ class JavaHiveQLSuite extends FunSuite { ignore("SELECT * FROM src") { assert( - javaHiveCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === + javaHiveCtx.sql("SELECT * FROM src").collect().map(_.getInt(0)) === TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) } @@ -56,33 +56,34 @@ class JavaHiveQLSuite extends FunSuite { val tableName = "test_native_commands" assertResult(0) { - javaHiveCtx.hql(s"DROP TABLE IF EXISTS $tableName").count() + javaHiveCtx.sql(s"DROP TABLE IF EXISTS $tableName").count() } assertResult(0) { - javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.hql("SHOW TABLES").registerTempTable("show_tables") + javaHiveCtx.sql("SHOW TABLES").registerTempTable("show_tables") assert( javaHiveCtx - .hql("SELECT result FROM show_tables") + .sql("SELECT result FROM show_tables") .collect() .map(_.getString(0)) .contains(tableName)) assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - javaHiveCtx.hql(s"DESCRIBE $tableName").registerTempTable("describe_table") + javaHiveCtx.sql(s"DESCRIBE $tableName").registerTempTable("describe_table") + javaHiveCtx - .hql("SELECT result FROM describe_table") + .sql("SELECT result FROM describe_table") .collect() .map(_.getString(0).split("\t").map(_.trim)) .toArray } - assert(isExplanation(javaHiveCtx.hql( + assert(isExplanation(javaHiveCtx.sql( s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() @@ -90,7 +91,7 @@ class JavaHiveQLSuite extends FunSuite { ignore("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail assert(Try(TestHive.table(tableName)).isSuccess) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 83cfbc6b4a002..0ebaf6ffd5458 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -241,13 +241,13 @@ abstract class HiveComparisonTest val quotes = "\"\"\"" queryList.zipWithIndex.map { case (query, i) => - s"""val q$i = hql($quotes$query$quotes); q$i.collect()""" + s"""val q$i = sql($quotes$query$quotes); q$i.collect()""" }.mkString("\n== Console version of this test ==\n", "\n", "\n") } try { // MINOR HACK: You must run a query before calling reset the first time. - TestHive.hql("SHOW TABLES") + TestHive.sql("SHOW TABLES") if (reset) { TestHive.reset() } val hiveCacheFiles = queryList.zipWithIndex.map { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4ed41550cf530..aa810a291231a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -57,8 +57,8 @@ class HiveQuerySuite extends HiveComparisonTest { """.stripMargin) test("CREATE TABLE AS runs once") { - hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, "Incorrect number of rows in created table") } @@ -72,12 +72,14 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") test("Query expressed in SQL") { + set("spark.sql.dialect", "sql") assert(sql("SELECT 1").collect() === Array(Seq(1))) + set("spark.sql.dialect", "hiveql") + } test("Query expressed in HiveQL") { - hql("FROM src SELECT key").collect() - hiveql("FROM src SELECT key").collect() + sql("FROM src SELECT key").collect() } createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT", @@ -193,12 +195,12 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") test("sampling") { - hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") } test("SchemaRDD toString") { - hql("SHOW TABLES").toString - hql("SELECT * FROM src").toString + sql("SHOW TABLES").toString + sql("SELECT * FROM src").toString } createQueryTest("case statements with key #1", @@ -226,8 +228,8 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") test("implement identity function using case statement") { - val actual = hql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet - val expected = hql("SELECT key FROM src").collect().toSet + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet + val expected = sql("SELECT key FROM src").collect().toSet assert(actual === expected) } @@ -235,7 +237,7 @@ class HiveQuerySuite extends HiveComparisonTest { // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. ignore("non-boolean conditions in a CaseWhen are illegal") { intercept[Exception] { - hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() } } @@ -250,7 +252,7 @@ class HiveQuerySuite extends HiveComparisonTest { testData.registerTempTable("REGisteredTABle") assertResult(Array(Array(2, "str2"))) { - hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + + sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } } @@ -261,9 +263,9 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-1704: Explain commands as a SchemaRDD") { - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val rdd = hql("explain select key, count(value) from src group by key") + val rdd = sql("explain select key, count(value) from src group by key") assert(isExplanation(rdd)) TestHive.reset() @@ -274,7 +276,7 @@ class HiveQuerySuite extends HiveComparisonTest { .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") val results = - hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() .map(x => Pair(x.getString(0), x.getInt(1))) @@ -283,39 +285,39 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { - hql("select key, count(*) c from src group by key having c").collect() + sql("select key, count(*) c from src group by key having c").collect() } test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { - assert(hql("select key from src having key > 490").collect().size < 100) + assert(sql("select key from src having key > 490").collect().size < 100) } test("Query Hive native command execution result") { val tableName = "test_native_commands" assertResult(0) { - hql(s"DROP TABLE IF EXISTS $tableName").count() + sql(s"DROP TABLE IF EXISTS $tableName").count() } assertResult(0) { - hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } assert( - hql("SHOW TABLES") + sql("SHOW TABLES") .select('result) .collect() .map(_.getString(0)) .contains(tableName)) - assert(isExplanation(hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() } test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail assert(Try(table(tableName)).isSuccess) @@ -325,9 +327,9 @@ class HiveQuerySuite extends HiveComparisonTest { } test("DESCRIBE commands") { - hql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - hql( + sql( """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') |SELECT key, value """.stripMargin) @@ -342,7 +344,7 @@ class HiveQuerySuite extends HiveComparisonTest { Array("# col_name", "data_type", "comment"), Array("dt", "string", null)) ) { - hql("DESCRIBE test_describe_commands1") + sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } @@ -357,14 +359,14 @@ class HiveQuerySuite extends HiveComparisonTest { Array("# col_name", "data_type", "comment"), Array("dt", "string", null)) ) { - hql("DESCRIBE default.test_describe_commands1") + sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - hql("DESCRIBE test_describe_commands1 value") + sql("DESCRIBE test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -372,7 +374,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - hql("DESCRIBE default.test_describe_commands1 value") + sql("DESCRIBE default.test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -390,7 +392,7 @@ class HiveQuerySuite extends HiveComparisonTest { Array("", "", ""), Array("dt", "string", "None")) ) { - hql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + sql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -409,16 +411,16 @@ class HiveQuerySuite extends HiveComparisonTest { Array("a", "IntegerType", null), Array("b", "StringType", null)) ) { - hql("DESCRIBE test_describe_commands2") + sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) .collect() } } test("SPARK-2263: Insert Map values") { - hql("CREATE TABLE m(value MAP)") - hql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - hql("SELECT * FROM m").collect().zip(hql("SELECT * FROM src LIMIT 10").collect()).map { + sql("CREATE TABLE m(value MAP)") + sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -430,18 +432,18 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "val0,val_1,val2.3,my_table" - hql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) - hql("set some.property=20") + sql("set some.property=20") assert(get("some.property", "0") == "20") - hql("set some.property = 40") + sql("set some.property = 40") assert(get("some.property", "0") == "40") - hql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(get(testKey, "0") == testVal) - hql(s"set $testKey=") + sql(s"set $testKey=") assert(get(testKey, "0") == "") } @@ -454,33 +456,34 @@ class HiveQuerySuite extends HiveComparisonTest { clear() // "set" itself returns all config variables currently specified in SQLConf. - assert(hql("SET").collect().size == 0) + // TODO: Should we be listing the default here always? probably... + assert(sql("SET").collect().size == 0) assertResult(Array(s"$testKey=$testVal")) { - hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Array(s"$testKey=$testVal")) { - hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } - hql(s"SET ${testKey + testKey}=${testVal + testVal}") + sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { - hql(s"SET").collect().map(_.getString(0)) + sql(s"SET").collect().map(_.getString(0)) } // "set key" assertResult(Array(s"$testKey=$testVal")) { - hql(s"SET $testKey").collect().map(_.getString(0)) + sql(s"SET $testKey").collect().map(_.getString(0)) } assertResult(Array(s"$nonexistentKey=")) { - hql(s"SET $nonexistentKey").collect().map(_.getString(0)) + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } - // Assert that sql() should have the same effects as hql() by repeating the above using sql(). + // Assert that sql() should have the same effects as sql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 2455c18925dfa..6b3ffd1c0ffe2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -56,13 +56,13 @@ class HiveResolutionSuite extends HiveComparisonTest { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) .registerTempTable("caseSensitivityTest") - hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } test("nested repeated resolution") { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("nestedRepeatedTest") - assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) + .registerTempTable("nestedRepeatedTest") + assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 7436de264a1e1..c3c18cf8ccac3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,7 +35,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f944d010660eb..b6b8592344ef5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject */ class HiveUdfSuite extends HiveComparisonTest { - TestHive.hql( + TestHive.sql( """ |CREATE EXTERNAL TABLE hiveUdfTestTable ( | pair STRUCT @@ -48,16 +48,16 @@ class HiveUdfSuite extends HiveComparisonTest { """.stripMargin.format(classOf[PairSerDe].getName) ) - TestHive.hql( + TestHive.sql( "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) ) - TestHive.hql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) + TestHive.sql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) - TestHive.hql("SELECT testUdf(pair) FROM hiveUdfTestTable") + TestHive.sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - TestHive.hql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + TestHive.sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 34d8a061ccc83..1a6dbc0ce0c0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConversions._ */ class PruningSuite extends HiveComparisonTest { // MINOR HACK: You must run a query before calling reset the first time. - TestHive.hql("SHOW TABLES") + TestHive.sql("SHOW TABLES") // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset // the environment to ensure all referenced tables in this suites are not cached in-memory. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 6545e8d7dcb69..6f57fe8958387 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -68,39 +68,40 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft .saveAsParquetFile(tempFile.getCanonicalPath) parquetFile(tempFile.getCanonicalPath).registerTempTable("cases") - hql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) - hql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + sql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + sql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) } test("SELECT on Parquet table") { - val rdd = hql("SELECT * FROM testsource").collect() + val rdd = sql("SELECT * FROM testsource").collect() assert(rdd != null) assert(rdd.forall(_.size == 6)) } test("Simple column projection + filter on Parquet table") { - val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() + val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() assert(rdd.size === 5, "Filter returned incorrect number of rows") assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") } test("Converting Hive to Parquet Table via saveAsParquetFile") { - hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) + sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") - val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0)) - val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0)) + val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0)) + val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0)) + compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) } test("INSERT OVERWRITE TABLE Parquet table") { - hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) + sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") // let's do three overwrites for good measure - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - val rddCopy = hql("SELECT * FROM ptable").collect() - val rddOrig = hql("SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + val rddCopy = sql("SELECT * FROM ptable").collect() + val rddOrig = sql("SELECT * FROM testsource").collect() assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) } From 6ffdcc61fb4825f991b754c45b807192f483a4a3 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 3 Aug 2014 12:34:46 -0700 Subject: [PATCH 018/231] [SPARK-2814][SQL] HiveThriftServer2 throws NPE when executing native commands JIRA issue: [SPARK-2814](https://issues.apache.org/jira/browse/SPARK-2814) Author: Cheng Lian Closes #1753 from liancheng/spark-2814 and squashes the following commits: c74a3b2 [Cheng Lian] Fixed SPARK-2814 (cherry picked from commit ac33cbbf33bd1ab29bc8165c9be02fb8934b1fdf) Signed-off-by: Michael Armbrust --- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7db0159512610..acad681f68b14 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -146,13 +146,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. - - ss.err = new PrintStream(outputBuffer, true, "UTF-8") - ss.out = new PrintStream(outputBuffer, true, "UTF-8") - ss } + sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") + sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") + override def set(key: String, value: String): Unit = { super.set(key, value) runSqlHive(s"SET $key=$value") From 7c6afdac867d52447221438ed7508123c07d17f8 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 3 Aug 2014 14:54:41 -0700 Subject: [PATCH 019/231] [SPARK-2783][SQL] Basic support for analyze in HiveContext JIRA: https://issues.apache.org/jira/browse/SPARK-2783 Author: Yin Huai Closes #1741 from yhuai/analyzeTable and squashes the following commits: 7bb5f02 [Yin Huai] Use sql instead of hql. 4d09325 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable e3ebcd4 [Yin Huai] Renaming. c170f4e [Yin Huai] Do not use getContentSummary. 62393b6 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable db233a6 [Yin Huai] Trying to debug jenkins... fee84f0 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable f0501f3 [Yin Huai] Fix compilation error. 24ad391 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable 8918140 [Yin Huai] Wording. 23df227 [Yin Huai] Add a simple analyze method to get the size of a table and update the "totalSize" property of this table in the Hive metastore. (cherry picked from commit e139e2be60ef23281327744e1b3e74904dfdf63f) Signed-off-by: Michael Armbrust --- .../apache/spark/sql/hive/HiveContext.scala | 79 +++++++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +- .../spark/sql/hive/StatisticsSuite.scala | 54 +++++++++++++ 3 files changed, 136 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index acad681f68b14..d8e7a5943daa5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -25,10 +25,14 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.io.TimestampWritable import org.apache.spark.SparkContext @@ -107,6 +111,81 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } + /** + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. + */ + def analyze(tableName: String) { + val relation = catalog.lookupRelation(None, tableName) match { + case LowerCaseSchema(r) => r + case o => o + } + + relation match { + case relation: MetastoreRelation => { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDir) { + fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + } else { + fileStatus.getLen + } + + size + } + + def getFileSizeForTable(conf: HiveConf, table: Table): Long = { + val path = table.getPath() + var size: Long = 0L + try { + val fs = path.getFileSystem(conf) + size = calculateTableSize(fs, path) + } catch { + case e: Exception => + logWarning( + s"Failed to get the size of table ${table.getTableName} in the " + + s"database ${table.getDbName} because of ${e.toString}", e) + size = 0L + } + + size + } + + val tableParameters = relation.hiveQlTable.getParameters + val oldTotalSize = + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)).map(_.toLong).getOrElse(0L) + val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) + // Update the Hive metastore if the total size of the table is different than the size + // recorded in the Hive metastore. + // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). + if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + tableParameters.put(StatsSetupConst.TOTAL_SIZE, newTotalSize.toString) + val hiveTTable = relation.hiveQlTable.getTTable + hiveTTable.setParameters(tableParameters) + val tableFullName = + relation.hiveQlTable.getDbName() + "." + relation.hiveQlTable.getTableName() + + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } + } + case otherRelation => + throw new NotImplementedError( + s"Analyze has only implemented for Hive tables, " + + s"but ${tableName} is a ${otherRelation.nodeName}") + } + } + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient protected lazy val outputBuffer = new java.io.OutputStream { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index df3604439e483..301cf51c00e2b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, Ser import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi @@ -278,9 +279,9 @@ private[hive] case class MetastoreRelation // relatively cheap if parameters for the table are populated into the metastore. An // alternative would be going through Hadoop's FileSystem API, which can be expensive if a lot // of RPCs are involved. Besides `totalSize`, there are also `numFiles`, `numRows`, - // `rawDataSize` keys that we can look at in the future. + // `rawDataSize` keys (see StatsSetupConst in Hive) that we can look at in the future. BigInt( - Option(hiveQlTable.getParameters.get("totalSize")) + Option(hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)) .map(_.toLong) .getOrElse(sqlContext.defaultSizeInBytes)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index d8c77d6021d63..bf5931bbf97ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -26,6 +26,60 @@ import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest { + test("analyze MetastoreRelations") { + def queryTotalSize(tableName: String): BigInt = + catalog.lookupRelation(None, tableName).statistics.sizeInBytes + + // Non-partitioned table + sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + + assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) + + analyze("analyzeTable") + + assert(queryTotalSize("analyzeTable") === BigInt(11624)) + + sql("DROP TABLE analyzeTable").collect() + + // Partitioned table + sql( + """ + |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') + |SELECT * FROM src + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') + |SELECT * FROM src + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') + |SELECT * FROM src + """.stripMargin).collect() + + assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) + + analyze("analyzeTable_part") + + assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) + + sql("DROP TABLE analyzeTable_part").collect() + + // Try to analyze a temp table + sql("""SELECT * FROM src""").registerTempTable("tempTable") + intercept[NotImplementedError] { + analyze("tempTable") + } + catalog.unregisterTable(None, "tempTable") + } + test("estimates the size of a test MetastoreRelation") { val rdd = sql("""SELECT * FROM src""") val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => From a4cdb77e5ee2c80967a7b6cd7370170fabe56cd2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 3 Aug 2014 15:52:00 -0700 Subject: [PATCH 020/231] [SPARK-1740] [PySpark] kill the python worker Kill only the python worker related to cancelled tasks. The daemon will start a background thread to monitor all the opened sockets for all workers. If the socket is closed by JVM, this thread will kill the worker. When an task is cancelled, the socket to worker will be closed, then the worker will be killed by deamon. Author: Davies Liu Closes #1643 from davies/kill and squashes the following commits: 8ffe9f3 [Davies Liu] kill worker by deamon, because runtime.exec() is too heavy 46ca150 [Davies Liu] address comment acd751c [Davies Liu] kill the worker when task is canceled (cherry picked from commit 55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9) Signed-off-by: Josh Rosen --- .../scala/org/apache/spark/SparkEnv.scala | 5 +- .../apache/spark/api/python/PythonRDD.scala | 9 ++- .../api/python/PythonWorkerFactory.scala | 64 ++++++++++++++----- python/pyspark/daemon.py | 24 +++++-- python/pyspark/tests.py | 51 +++++++++++++++ 5 files changed, 125 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 92c809d854167..0bce531aaba3e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.net.Socket import scala.collection.JavaConversions._ import scala.collection.mutable @@ -102,10 +103,10 @@ class SparkEnv ( } private[spark] - def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) { + def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { synchronized { val key = (pythonExec, envVars) - pythonWorkers(key).stop() + pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fe9a9e50ef21d..0b5322c6fb965 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -62,8 +62,8 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - val worker: Socket = env.createPythonWorker(pythonExec, - envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir)) + envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread + val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) @@ -241,7 +241,7 @@ private[spark] class PythonRDD( if (!context.completed) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap) + env.destroyPythonWorker(pythonExec, envVars.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging { /** * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - * This function is outdated, PySpark does not use it anymore */ - @deprecated + @deprecated("PySpark does not use it anymore", "1.1") def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 15fe8a9be6bfe..7af260d0b7f26 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, InputStream, OutputStreamWriter} +import java.lang.Runtime +import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import scala.collection.mutable import scala.collection.JavaConversions._ import org.apache.spark._ @@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 + var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + + var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, @@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. */ private def createThroughDaemon(): Socket = { + + def createSocket(): Socket = { + val socket = new Socket(daemonHost, daemonPort) + val pid = new DataInputStream(socket.getInputStream).readInt() + if (pid < 0) { + throw new IllegalStateException("Python daemon failed to launch worker") + } + daemonWorkers.put(socket, pid) + socket + } + synchronized { // Start the daemon if it hasn't been started startDaemon() // Attempt to connect, restart and retry once if it fails try { - val socket = new Socket(daemonHost, daemonPort) - val launchStatus = new DataInputStream(socket.getInputStream).readInt() - if (launchStatus != 0) { - throw new IllegalStateException("Python daemon failed to launch worker") - } - socket + createSocket() } catch { case exc: SocketException => logWarning("Failed to open socket to Python daemon:", exc) logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() - new Socket(daemonHost, daemonPort) + createSocket() } } } @@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Wait for it to connect to our socket serverSocket.setSoTimeout(10000) try { - return serverSocket.accept() + val socket = serverSocket.accept() + simpleWorkers.put(socket, worker) + return socket } catch { case e: Exception => throw new SparkException("Python worker did not connect back in time", e) @@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String private def stopDaemon() { synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy() - } + if (useDaemon) { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy() + } - daemon = null - daemonPort = 0 + daemon = null + daemonPort = 0 + } else { + simpleWorkers.mapValues(_.destroy()) + } } } def stop() { stopDaemon() } + + def stopWorker(worker: Socket) { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } + } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) + } + worker.close() + } } private object PythonWorkerFactory { diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 9fde0dde0f4b4..b00da833d06f1 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -26,7 +26,7 @@ from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main -from pyspark.serializers import write_int +from pyspark.serializers import read_int, write_int def compute_real_exit_code(exit_code): @@ -67,7 +67,8 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - write_int(0, outfile) # Acknowledge that the fork was successful + # Acknowledge that the fork was successful + write_int(os.getpid(), outfile) outfile.flush() worker_main(infile, outfile) except SystemExit as exc: @@ -125,14 +126,23 @@ def handle_sigchld(*args): else: raise if 0 in ready_fds: - # Spark told us to exit by closing stdin - shutdown(0) + try: + worker_pid = read_int(sys.stdin) + except EOFError: + # Spark told us to exit by closing stdin + shutdown(0) + try: + os.kill(worker_pid, signal.SIGKILL) + except OSError: + pass # process already died + + if listen_sock in ready_fds: sock, addr = listen_sock.accept() # Launch a worker process try: - fork_return_code = os.fork() - if fork_return_code == 0: + pid = os.fork() + if pid == 0: listen_sock.close() try: worker(sock) @@ -143,11 +153,13 @@ def handle_sigchld(*args): os._exit(0) else: sock.close() + except OSError as e: print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) write_int(-1, outfile) # Signal that the fork failed outfile.flush() + outfile.close() sock.close() finally: shutdown(1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 16fb5a9256220..acc3c30371621 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -790,6 +790,57 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) +class TestWorker(PySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + def sleep(x): + import os, time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + self.sc.parallelize(range(1)).foreach(sleep) + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + data = open(path).read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + def test_fd_leak(self): + N = 1100 # fd limit is 1024 by default + rdd = self.sc.parallelize(range(N), N) + self.assertEquals(N, rdd.count()) + + class TestSparkSubmit(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() From 4784d24eadea2e1adf69d8fe4891bdce29188dd6 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Sun, 3 Aug 2014 17:47:49 -0700 Subject: [PATCH 021/231] [SPARK-2810] upgrade to scala-maven-plugin 3.2.0 Needed for Scala 2.11 compiler-interface Signed-off-by: Anand Avati Author: Anand Avati Closes #1711 from avati/SPARK-1812-scala-maven-plugin and squashes the following commits: 9a22fc8 [Anand Avati] SPARK-1812: upgrade to scala-maven-plugin 3.2.0 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index cc9377cec2a07..4ab027bad55c0 100644 --- a/pom.xml +++ b/pom.xml @@ -782,7 +782,7 @@ net.alchim31.maven scala-maven-plugin - 3.1.6 + 3.2.0 scala-compile-first From 2152e24d64d6a07cf6c550c9f13ab0231596be98 Mon Sep 17 00:00:00 2001 From: Sarah Gerweck Date: Sun, 3 Aug 2014 19:47:05 -0700 Subject: [PATCH 022/231] Fix some bugs with spaces in directory name. Any time you use the directory name (`FWDIR`) it needs to be surrounded in quotes. If you're also using wildcards, you can safely put the quotes around just `$FWDIR`. Author: Sarah Gerweck Closes #1756 from sarahgerweck/folderSpaces and squashes the following commits: 732629d [Sarah Gerweck] Fix some bugs with spaces in directory name. (cherry picked from commit 5507dd8e18fbb52d5e0c64a767103b2418cb09c6) Signed-off-by: Patrick Wendell --- make-distribution.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index 1441497b3995a..f7a6a9d838bb6 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -168,22 +168,22 @@ mkdir -p "$DISTDIR/lib" echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE" # Copy jars -cp $FWDIR/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" -cp $FWDIR/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" +cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" +cp -r "$FWDIR"/examples/src/main "$DISTDIR/examples/src/" if [ "$SPARK_HIVE" == "true" ]; then - cp $FWDIR/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" + cp "$FWDIR"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" fi # Copy license and ASF files cp "$FWDIR/LICENSE" "$DISTDIR" cp "$FWDIR/NOTICE" "$DISTDIR" -if [ -e $FWDIR/CHANGES.txt ]; then +if [ -e "$FWDIR"/CHANGES.txt ]; then cp "$FWDIR/CHANGES.txt" "$DISTDIR" fi From 9aa14598f89bb8b908222e37f965178d39c34fe6 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sun, 3 Aug 2014 21:39:21 -0700 Subject: [PATCH 023/231] SPARK-2272 [MLlib] Feature scaling which standardizes the range of independent variables or features of data Feature scaling is a method used to standardize the range of independent variables or features of data. In data processing, it is generally performed during the data preprocessing step. In this work, a trait called `VectorTransformer` is defined for generic transformation on a vector. It contains one method to be implemented, `transform` which applies transformation on a vector. There are two implementations of `VectorTransformer` now, and they all can be easily extended with PMML transformation support. 1) `StandardScaler` - Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. 2) `Normalizer` - Normalizes samples individually to unit L^n norm Author: DB Tsai Closes #1207 from dbtsai/dbtsai-feature-scaling and squashes the following commits: 78c15d3 [DB Tsai] Alpine Data Labs (cherry picked from commit ae58aea2d1435b5bb011e68127e1bcddc2edf5b2) Signed-off-by: Xiangrui Meng --- .../spark/mllib/feature/Normalizer.scala | 76 +++++++ .../spark/mllib/feature/StandardScaler.scala | 119 +++++++++++ .../mllib/feature/VectorTransformer.scala | 51 +++++ .../mllib/linalg/distributed/RowMatrix.scala | 2 +- .../spark/mllib/feature/NormalizerSuite.scala | 120 +++++++++++ .../mllib/feature/StandardScalerSuite.scala | 200 ++++++++++++++++++ 6 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala new file mode 100644 index 0000000000000..ea9fd0a80d8e0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * :: DeveloperApi :: + * Normalizes samples individually to unit L^p^ norm + * + * For any 1 <= p < Double.PositiveInfinity, normalizes samples using + * sum(abs(vector).^p^)^(1/p)^ as norm. + * + * For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization. + * + * @param p Normalization in L^p^ space, p = 2 by default. + */ +@DeveloperApi +class Normalizer(p: Double) extends VectorTransformer { + + def this() = this(2) + + require(p >= 1.0) + + /** + * Applies unit length normalization on a vector. + * + * @param vector vector to be normalized. + * @return normalized vector. If the norm of the input is zero, it will return the input vector. + */ + override def transform(vector: Vector): Vector = { + var norm = vector.toBreeze.norm(p) + + if (norm != 0.0) { + // For dense vector, we've to allocate new memory for new output vector. + // However, for sparse vector, the `index` array will not be changed, + // so we can re-use it to save memory. + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) + case sv: BSV[Double] => + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) /= norm + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Since the norm is zero, return the input vector object itself. + // Note that it's safe since we always assume that the data in RDD + // should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala new file mode 100644 index 0000000000000..cc2d7579c2901 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. + * + * @param withMean False by default. Centers the data with mean before scaling. It will build a + * dense output, so this does not work on sparse input and will raise an exception. + * @param withStd True by default. Scales the data to unit standard deviation. + */ +@DeveloperApi +class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { + + def this() = this(false, true) + + require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") + + private var mean: BV[Double] = _ + private var factor: BV[Double] = _ + + /** + * Computes the mean and variance and stores as a model to be used for later scaling. + * + * @param data The data used to compute the mean and variance to build the transformation model. + * @return This StandardScalar object. + */ + def fit(data: RDD[Vector]): this.type = { + val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + + mean = summary.mean.toBreeze + factor = summary.variance.toBreeze + require(mean.length == factor.length) + + var i = 0 + while (i < factor.length) { + factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 + i += 1 + } + + this + } + + /** + * Applies standardization transformation on a vector. + * + * @param vector Vector to be standardized. + * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` + * for the column with zero variance. + */ + override def transform(vector: Vector): Vector = { + if (mean == null || factor == null) { + throw new IllegalStateException( + "Haven't learned column summary statistics yet. Call fit first.") + } + + require(vector.size == mean.length) + + if (withMean) { + vector.toBreeze match { + case dv: BDV[Double] => + val output = vector.toBreeze.copy + var i = 0 + while (i < output.length) { + output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else if (withStd) { + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) + case sv: BSV[Double] => + // For sparse vector, the `index` array inside sparse vector object will not be changed, + // so we can re-use it to save memory. + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) *= factor(output.index(i)) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Note that it's safe since we always assume that the data in RDD should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala new file mode 100644 index 0000000000000..415a845332d45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Trait for transformation of a vector + */ +@DeveloperApi +trait VectorTransformer extends Serializable { + + /** + * Applies transformation on a vector. + * + * @param vector vector to be transformed. + * @return transformed vector. + */ + def transform(vector: Vector): Vector + + /** + * Applies transformation on an RDD[Vector]. + * + * @param data RDD[Vector] to be transformed. + * @return transformed RDD[Vector]. + */ + def transform(data: RDD[Vector]): RDD[Vector] = { + // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. + // So it should be no longer necessary to explicitly broadcast `this` object. + data.map(x => this.transform(x)) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 58c1322757a43..45486b2c7d82d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import java.util.Arrays -import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala new file mode 100644 index 0000000000000..fb76dccfdf79e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class NormalizerSuite extends FunSuite with LocalSparkContext { + + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + + lazy val dataRDD = sc.parallelize(data, 3) + + test("Normalization using L1 distance") { + val l1Normalizer = new Normalizer(1) + + val data1 = data.map(l1Normalizer.transform) + val data1RDD = l1Normalizer.transform(dataRDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + + assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) + assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data1(2) ~== Vectors.dense(0.12765957, -0.23404255, -0.63829787) absTol 1E-5) + assert(data1(3) ~== Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))) absTol 1E-5) + assert(data1(4) ~== Vectors.dense(0.625, 0.07894737, 0.29605263) absTol 1E-5) + assert(data1(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L2 distance") { + val l2Normalizer = new Normalizer() + + val data2 = data.map(l2Normalizer.transform) + val data2RDD = l2Normalizer.transform(dataRDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + + assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) + assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data2(2) ~== Vectors.dense(0.184549876, -0.3383414, -0.922749378) absTol 1E-5) + assert(data2(3) ~== Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.897906166, 0.113419726, 0.42532397) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L^Inf distance.") { + val lInfNormalizer = new Normalizer(Double.PositiveInfinity) + + val dataInf = data.map(lInfNormalizer.transform) + val dataInfRDD = lInfNormalizer.transform(dataRDD) + + assert((data, dataInf, dataInfRDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((dataInf, dataInfRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(dataInf(0).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(2).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(3).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(4).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + + assert(dataInf(0) ~== Vectors.sparse(3, Seq((0, -0.86956522), (1, 1.0))) absTol 1E-5) + assert(dataInf(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(dataInf(2) ~== Vectors.dense(0.2, -0.36666667, -1.0) absTol 1E-5) + assert(dataInf(3) ~== Vectors.sparse(3, Seq((1, 0.284375), (2, 1.0))) absTol 1E-5) + assert(dataInf(4) ~== Vectors.dense(1.0, 0.12631579, 0.473684211) absTol 1E-5) + assert(dataInf(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala new file mode 100644 index 0000000000000..5a9be923a8625 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.rdd.RDD + +class StandardScalerSuite extends FunSuite with LocalSparkContext { + + private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { + data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + } + + test("Standardization with dense input") { + val data = Array( + Vectors.dense(-2.0, 2.3, 0), + Vectors.dense(0.0, -1.0, -3.0), + Vectors.dense(0.0, -5.1, 0.0), + Vectors.dense(3.8, 0.0, 1.9), + Vectors.dense(1.7, -0.6, 0.0), + Vectors.dense(0.0, 1.9, 0.0) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + withClue("Using a standardizer before fitting the model should throw exception.") { + intercept[IllegalStateException] { + data.map(standardizer1.transform) + } + } + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + val data1RDD = standardizer1.transform(dataRDD) + val data2RDD = standardizer2.transform(dataRDD) + val data3RDD = standardizer3.transform(dataRDD) + + val summary = computeSummary(dataRDD) + val summary1 = computeSummary(data1RDD) + val summary2 = computeSummary(data2RDD) + val summary3 = computeSummary(data3RDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data3, data3RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary3.variance ~== summary.variance absTol 1E-5) + + assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5) + assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5) + assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5) + assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5) + assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5) + } + + + test("Standardization with sparse input") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), + Vectors.sparse(3, Seq((1, -5.1))), + Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), + Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), + Vectors.sparse(3, Seq((1, 1.9))) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data2 = data.map(standardizer2.transform) + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer1.transform) + } + } + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer3.transform) + } + } + + val data2RDD = standardizer2.transform(dataRDD) + + val summary2 = computeSummary(data2RDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) + } + + test("Standardization with constant input") { + // When the input data is all constant, the variance is zero. The standardization against + // zero variance is not well-defined, but we decide to just set it into zero here. + val data = Array( + Vectors.dense(2.0), + Vectors.dense(2.0), + Vectors.dense(2.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler(withMean = true, withStd = false) + val standardizer3 = new StandardScaler(withMean = false, withStd = true) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + assert(data1.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data2.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data3.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + } + +} From 3823f6d25e2a89ca1bfa62a76f6e708c2c63f064 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 3 Aug 2014 23:55:58 -0700 Subject: [PATCH 024/231] [MLlib] [SPARK-2510]Word2Vec: Distributed Representation of Words This is a pull request regarding SPARK-2510 at https://issues.apache.org/jira/browse/SPARK-2510. Word2Vec creates vector representation of words in a text corpus. The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary. The vector representation can be used as features in natural language processing and machine learning algorithms. To make our implementation more scalable, we train each partition separately and merge the model of each partition after each iteration. To make the model more accurate, multiple iterations may be needed. To investigate the vector representations is to find the closest words for a query word. For example, the top 20 closest words to "china" are for 1 partition and 1 iteration : taiwan 0.8077646146334014 korea 0.740913304563621 japan 0.7240667798885471 republic 0.7107151279078352 thailand 0.6953217332072862 tibet 0.6916782118129544 mongolia 0.6800858715972612 macau 0.6794925677480378 singapore 0.6594048695593799 manchuria 0.658989931844148 laos 0.6512978726001666 nepal 0.6380792327845325 mainland 0.6365469459587788 myanmar 0.6358614338840394 macedonia 0.6322366180313249 xinjiang 0.6285291551708028 russia 0.6279951236068411 india 0.6272874944023487 shanghai 0.6234544135576999 macao 0.6220588462925876 The result with 10 partitions and 5 iterations is: taiwan 0.8310495079388313 india 0.7737171315919039 japan 0.756777901233668 korea 0.7429767187102452 indonesia 0.7407557427278356 pakistan 0.712883426985585 mainland 0.7053379963140822 thailand 0.696298191073948 mongolia 0.693690656871415 laos 0.6913069680735292 macau 0.6903427690029617 republic 0.6766381604813666 malaysia 0.676460699141784 singapore 0.6728790997360923 malaya 0.672345232966194 manchuria 0.6703732292753156 macedonia 0.6637955686322028 myanmar 0.6589462882439646 kazakhstan 0.657017801081494 cambodia 0.6542383836451932 Author: Liquan Pei Author: Xiangrui Meng Author: Liquan Pei Closes #1719 from Ishiihara/master and squashes the following commits: 2ba9483 [Liquan Pei] minor fix for Word2Vec test e248441 [Liquan Pei] minor style change 26a948d [Liquan Pei] Merge pull request #1 from mengxr/Ishiihara-master c14da41 [Xiangrui Meng] fix styles 384c771 [Xiangrui Meng] remove minCount and window from constructor change model to use float instead of double e93e726 [Liquan Pei] use treeAggregate instead of aggregate 1a8fb41 [Liquan Pei] use weighted sum in combOp 7efbb6f [Liquan Pei] use broadcast version of vocab in aggregate 6bcc8be [Liquan Pei] add multiple iteration support 720b5a3 [Liquan Pei] Add test for Word2Vec algorithm, minor fixes 2e92b59 [Liquan Pei] modify according to feedback 57dc50d [Liquan Pei] code formatting e4a04d3 [Liquan Pei] minor fix 0aafb1b [Liquan Pei] Add comments, minor fixes 8d6befe [Liquan Pei] initial commit (cherry picked from commit e053c55819363fab7068bb9165e3379f0c2f570c) Signed-off-by: Xiangrui Meng --- .../apache/spark/mllib/feature/Word2Vec.scala | 424 ++++++++++++++++++ .../spark/mllib/feature/Word2VecSuite.scala | 61 +++ 2 files changed, 485 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala new file mode 100644 index 0000000000000..87c81e7b0bd2f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.{HashPartitioner, Logging} +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel + +/** + * Entry in vocabulary + */ +private case class VocabWord( + var word: String, + var cn: Int, + var point: Array[Int], + var code: Array[Int], + var codeLen:Int +) + +/** + * :: Experimental :: + * Word2Vec creates vector representation of words in a text corpus. + * The algorithm first constructs a vocabulary from the corpus + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in + * natural language processing and machine learning algorithms. + * + * We used skip-gram model in our implementation and hierarchical softmax + * method to train the model. The variable names in the implementation + * matches the original C implementation. + * + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see + * Efficient Estimation of Word Representations in Vector Space + * and + * Distributed Representations of Words and Phrases and their Compositionality. + * @param size vector dimension + * @param startingAlpha initial learning rate + * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) + * @param numIterations number of iterations to run, should be smaller than or equal to parallelism + */ +@Experimental +class Word2Vec( + val size: Int, + val startingAlpha: Double, + val parallelism: Int, + val numIterations: Int) extends Serializable with Logging { + + /** + * Word2Vec with a single thread. + */ + def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1) + + private val EXP_TABLE_SIZE = 1000 + private val MAX_EXP = 6 + private val MAX_CODE_LENGTH = 40 + private val MAX_SENTENCE_LENGTH = 1000 + private val layer1Size = size + private val modelPartitionNum = 100 + + /** context words from [-window, window] */ + private val window = 5 + + /** minimum frequency to consider a vocabulary word */ + private val minCount = 5 + + private var trainWordsCount = 0 + private var vocabSize = 0 + private var vocab: Array[VocabWord] = null + private var vocabHash = mutable.HashMap.empty[String, Int] + private var alpha = startingAlpha + + private def learnVocab(words:RDD[String]): Unit = { + vocab = words.map(w => (w, 1)) + .reduceByKey(_ + _) + .map(x => VocabWord( + x._1, + x._2, + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + 0)) + .filter(_.cn >= minCount) + .collect() + .sortWith((a, b) => a.cn > b.cn) + + vocabSize = vocab.length + var a = 0 + while (a < vocabSize) { + vocabHash += vocab(a).word -> a + trainWordsCount += vocab(a).cn + a += 1 + } + logInfo("trainWordsCount = " + trainWordsCount) + } + + private def createExpTable(): Array[Float] = { + val expTable = new Array[Float](EXP_TABLE_SIZE) + var i = 0 + while (i < EXP_TABLE_SIZE) { + val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP) + expTable(i) = (tmp / (tmp + 1.0)).toFloat + i += 1 + } + expTable + } + + private def createBinaryTree(): Unit = { + val count = new Array[Long](vocabSize * 2 + 1) + val binary = new Array[Int](vocabSize * 2 + 1) + val parentNode = new Array[Int](vocabSize * 2 + 1) + val code = new Array[Int](MAX_CODE_LENGTH) + val point = new Array[Int](MAX_CODE_LENGTH) + var a = 0 + while (a < vocabSize) { + count(a) = vocab(a).cn + a += 1 + } + while (a < 2 * vocabSize) { + count(a) = 1e9.toInt + a += 1 + } + var pos1 = vocabSize - 1 + var pos2 = vocabSize + + var min1i = 0 + var min2i = 0 + + a = 0 + while (a < vocabSize - 1) { + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min1i = pos1 + pos1 -= 1 + } else { + min1i = pos2 + pos2 += 1 + } + } else { + min1i = pos2 + pos2 += 1 + } + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min2i = pos1 + pos1 -= 1 + } else { + min2i = pos2 + pos2 += 1 + } + } else { + min2i = pos2 + pos2 += 1 + } + count(vocabSize + a) = count(min1i) + count(min2i) + parentNode(min1i) = vocabSize + a + parentNode(min2i) = vocabSize + a + binary(min2i) = 1 + a += 1 + } + // Now assign binary code to each vocabulary word + var i = 0 + a = 0 + while (a < vocabSize) { + var b = a + i = 0 + while (b != vocabSize * 2 - 2) { + code(i) = binary(b) + point(i) = b + i += 1 + b = parentNode(b) + } + vocab(a).codeLen = i + vocab(a).point(0) = vocabSize - 2 + b = 0 + while (b < i) { + vocab(a).code(i - b - 1) = code(b) + vocab(a).point(i - b) = point(b) - vocabSize + b += 1 + } + a += 1 + } + } + + /** + * Computes the vector representation of each word in vocabulary. + * @param dataset an RDD of words + * @return a Word2VecModel + */ + def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { + + val words = dataset.flatMap(x => x) + + learnVocab(words) + + createBinaryTree() + + val sc = dataset.context + + val expTable = sc.broadcast(createExpTable()) + val bcVocab = sc.broadcast(vocab) + val bcVocabHash = sc.broadcast(vocabHash) + + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => + new Iterator[Array[Int]] { + def hasNext: Boolean = iter.hasNext + + def next(): Array[Int] = { + var sentence = new ArrayBuffer[Int] + var sentenceLength = 0 + while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { + val word = bcVocabHash.value.get(iter.next()) + word match { + case Some(w) => + sentence += w + sentenceLength += 1 + case None => + } + } + sentence.toArray + } + } + } + + val newSentences = sentences.repartition(parallelism).cache() + var syn0Global = + Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size) + var syn1Global = new Array[Float](vocabSize * layer1Size) + + for(iter <- 1 to numIterations) { + val (aggSyn0, aggSyn1, _, _) = + // TODO: broadcast temp instead of serializing it directly + // or initialize the model in each executor + newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))( + seqOp = (c, v) => (c, v) match { + case ((syn0, syn1, lastWordCount, wordCount), sentence) => + var lwc = lastWordCount + var wc = wordCount + if (wordCount - lastWordCount > 10000) { + lwc = wordCount + alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + } + wc += sentence.size + var pos = 0 + while (pos < sentence.size) { + val word = sentence(pos) + // TODO: fix random seed + val b = Random.nextInt(window) + // Train Skip-gram + var a = b + while (a < window * 2 + 1 - b) { + if (a != window) { + val c = pos - window + a + if (c >= 0 && c < sentence.size) { + val lastWord = sentence(c) + val l1 = lastWord * layer1Size + val neu1e = new Array[Float](layer1Size) + // Hierarchical softmax + var d = 0 + while (d < bcVocab.value(word).codeLen) { + val l2 = bcVocab.value(word).point(d) * layer1Size + // Propagate hidden -> output + var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) + if (f > -MAX_EXP && f < MAX_EXP) { + val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt + f = expTable.value(ind) + val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat + blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + } + d += 1 + } + blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) + } + } + a += 1 + } + pos += 1 + } + (syn0, syn1, lwc, wc) + }, + combOp = (c1, c2) => (c1, c2) match { + case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) + val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) + blas.sscal(n, weight1, syn0_1, 1) + blas.sscal(n, weight1, syn1_1, 1) + blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) + blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + }) + syn0Global = aggSyn0 + syn1Global = aggSyn1 + } + newSentences.unpersist() + + val wordMap = new Array[(String, Array[Float])](vocabSize) + var i = 0 + while (i < vocabSize) { + val word = bcVocab.value(i).word + val vector = new Array[Float](layer1Size) + Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) + wordMap(i) = (word, vector) + i += 1 + } + val modelRDD = sc.parallelize(wordMap, modelPartitionNum) + .partitionBy(new HashPartitioner(modelPartitionNum)) + .persist(StorageLevel.MEMORY_AND_DISK) + + new Word2VecModel(modelRDD) + } +} + +/** +* Word2Vec model +*/ +class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable { + + private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { + require(v1.length == v2.length, "Vectors should have the same length") + val n = v1.length + val norm1 = blas.snrm2(n, v1, 1) + val norm2 = blas.snrm2(n, v2, 1) + if (norm1 == 0 || norm2 == 0) return 0.0 + blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + } + + /** + * Transforms a word to its vector representation + * @param word a word + * @return vector representation of word + */ + def transform(word: String): Vector = { + val result = model.lookup(word) + if (result.isEmpty) { + throw new IllegalStateException(s"$word not in vocabulary") + } + else Vectors.dense(result(0).map(_.toDouble)) + } + + /** + * Transforms an RDD to its vector representation + * @param dataset a an RDD of words + * @return RDD of vector representation + */ + def transform(dataset: RDD[String]): RDD[Vector] = { + dataset.map(word => transform(word)) + } + + /** + * Find synonyms of a word + * @param word a word + * @param num number of synonyms to find + * @return array of (word, similarity) + */ + def findSynonyms(word: String, num: Int): Array[(String, Double)] = { + val vector = transform(word) + findSynonyms(vector,num) + } + + /** + * Find synonyms of the vector representation of a word + * @param vector vector representation of a word + * @param num number of synonyms to find + * @return array of (word, cosineSimilarity) + */ + def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { + require(num > 0, "Number of similar words should > 0") + val topK = model.map { case(w, vec) => + (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) } + .sortByKey(ascending = false) + .take(num + 1) + .map(_.swap) + .tail + + topK + } +} + +object Word2Vec{ + /** + * Train Word2Vec model + * @param input RDD of words + * @param size vector dimension + * @param startingAlpha initial learning rate + * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) + * @param numIterations number of iterations, should be smaller than or equal to parallelism + * @return Word2Vec model + */ + def train[S <: Iterable[String]]( + input: RDD[S], + size: Int, + startingAlpha: Double, + parallelism: Int = 1, + numIterations:Int = 1): Word2VecModel = { + new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala new file mode 100644 index 0000000000000..b5db39b68a223 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext + +class Word2VecSuite extends FunSuite with LocalSparkContext { + + // TODO: add more tests + + test("Word2Vec") { + val sentence = "a b " * 100 + "a c " * 10 + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + val size = 10 + val startingAlpha = 0.025 + val window = 2 + val minCount = 2 + val num = 2 + + val model = Word2Vec.train(doc, size, startingAlpha) + val syms = model.findSynonyms("a", 2) + assert(syms.length == num) + assert(syms(0)._1 == "b") + assert(syms(1)._1 == "c") + } + + + test("Word2VecModel") { + val num = 2 + val localModel = Seq( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(sc.parallelize(localModel, 2)) + val syms = model.findSynonyms("china", num) + assert(syms.length == num) + assert(syms(0)._1 == "taiwan") + assert(syms(1)._1 == "japan") + } +} From bfd2f39581d958d5aafaa76994f44213bcdfbb69 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Aug 2014 12:13:41 -0700 Subject: [PATCH 025/231] [SPARK-1687] [PySpark] pickable namedtuple Add an hook to replace original namedtuple with an pickable one, then namedtuple could be used in RDDs. PS: pyspark should be import BEFORE "from collections import namedtuple" Author: Davies Liu Closes #1623 from davies/namedtuple and squashes the following commits: 045dad8 [Davies Liu] remove unrelated code changes 4132f32 [Davies Liu] address comment 55b1c1a [Davies Liu] fix tests 61f86eb [Davies Liu] replace all the reference of namedtuple to new hacked one 98df6c6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple f7b1bde [Davies Liu] add hack for CloudPickleSerializer 0c5c849 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple 21991e6 [Davies Liu] hack namedtuple in __main__ module, make it picklable. 93b03b8 [Davies Liu] pickable namedtuple (cherry picked from commit 59f84a9531f7974a053fd4963ce9afd88273ea4c) Signed-off-by: Josh Rosen --- python/pyspark/serializers.py | 60 +++++++++++++++++++++++++++++++++++ python/pyspark/tests.py | 19 +++++++++++ 2 files changed, 79 insertions(+) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 03b31ae9624c2..1b52c144df087 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -65,6 +65,9 @@ import marshal import struct import sys +import types +import collections + from pyspark import cloudpickle @@ -267,6 +270,63 @@ def dumps(self, obj): return obj +# Hook namedtuple, make it picklable + +__cls = {} + + +def _restore(name, fields, value): + """ Restore an object of namedtuple""" + k = (name, fields) + cls = __cls.get(k) + if cls is None: + cls = collections.namedtuple(name, fields) + __cls[k] = cls + return cls(*value) + + +def _hack_namedtuple(cls): + """ Make class generated by namedtuple picklable """ + name = cls.__name__ + fields = cls._fields + def __reduce__(self): + return (_restore, (name, fields, tuple(self))) + cls.__reduce__ = __reduce__ + return cls + + +def _hijack_namedtuple(): + """ Hack namedtuple() to make it picklable """ + global _old_namedtuple # or it will put in closure + + def _copy_func(f): + return types.FunctionType(f.func_code, f.func_globals, f.func_name, + f.func_defaults, f.func_closure) + + _old_namedtuple = _copy_func(collections.namedtuple) + + def namedtuple(name, fields, verbose=False, rename=False): + cls = _old_namedtuple(name, fields, verbose, rename) + return _hack_namedtuple(cls) + + # replace namedtuple with new one + collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.func_code = namedtuple.func_code + + # hack the cls already generated by namedtuple + # those created in other module can be pickled as normal, + # so only hack those in __main__ module + for n, o in sys.modules["__main__"].__dict__.iteritems(): + if (type(o) is type and o.__base__ is tuple + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__): + _hack_namedtuple(o) # hack inplace + + +_hijack_namedtuple() + + class PickleSerializer(FramedSerializer): """ Serializes objects using Python's cPickle serializer: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index acc3c30371621..4ac94ba729d35 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -112,6 +112,17 @@ def test_huge_dataset(self): m._cleanup() +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from cPickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEquals(p1, p2) + + class PySparkTestCase(unittest.TestCase): def setUp(self): @@ -298,6 +309,14 @@ def test_itemgetter(self): self.assertEqual([1], rdd.map(itemgetter(1)).collect()) self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + def test_namedtuple_in_rdd(self): + from collections import namedtuple + Person = namedtuple("Person", "id firstName lastName") + jon = Person(1, "Jon", "Doe") + jane = Person(2, "Jane", "Doe") + theDoes = self.sc.parallelize([jon, jane]) + self.assertEquals([jon, jane], theDoes.collect()) + class TestIO(PySparkTestCase): From aa7a48ee905b95e57f64051ea887d4775b427603 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 12:59:18 -0700 Subject: [PATCH 026/231] SPARK-2792. Fix reading too much or too little data from each stream in ExternalMap / Sorter All these changes are from mridulm's work in #1609, but extracted here to fix this specific issue and make it easier to merge not 1.1. This particular set of changes is to make sure that we read exactly the right range of bytes from each spill file in EAOM: some serializers can write bytes after the last object (e.g. the TC_RESET flag in Java serialization) and that would confuse the previous code into reading it as part of the next batch. There are also improvements to cleanup to make sure files are closed. In addition to bringing in the changes to ExternalAppendOnlyMap, I also copied them to the corresponding code in ExternalSorter and updated its test suite to test for the same issues. Author: Matei Zaharia Closes #1722 from mateiz/spark-2792 and squashes the following commits: 5d4bfb5 [Matei Zaharia] Make objectStreamReset counter count the last object written too 18fe865 [Matei Zaharia] Update docs on objectStreamReset 576ee83 [Matei Zaharia] Allow objectStreamReset to be 0 0374217 [Matei Zaharia] Remove super paranoid code to close file handles bda37bb [Matei Zaharia] Implement Mridul's ExternalAppendOnlyMap fixes in ExternalSorter too 0d6dad7 [Matei Zaharia] Added Mridul's test changes for ExternalAppendOnlyMap 9a78e4b [Matei Zaharia] Add @mridulm's fixes to ExternalAppendOnlyMap for batch sizes --- .../spark/serializer/JavaSerializer.scala | 5 +- .../collection/ExternalAppendOnlyMap.scala | 86 +++++++++++---- .../util/collection/ExternalSorter.scala | 104 +++++++++++++----- .../ExternalAppendOnlyMapSuite.scala | 33 ++++-- .../util/collection/ExternalSorterSuite.scala | 47 +++++--- docs/configuration.md | 2 +- 6 files changed, 194 insertions(+), 83 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index a7fa057ee05f7..34bc3124097bb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In /** * Calling reset to avoid memory leak: * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api - * But only call it every 10,000th time to avoid bloated serialization streams (when + * But only call it every 100th time to avoid bloated serialization streams (when * the stream 'resets' object class descriptions have to be re-written) */ def writeObject[T: ClassTag](t: T): SerializationStream = { objOut.writeObject(t) + counter += 1 if (counterReset > 0 && counter >= counterReset) { objOut.reset() counter = 0 - } else { - counter += 1 } this } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index cb67a1c039f20..5d10a1f84493c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException} +import java.io._ import java.util.Comparator import scala.collection.BufferedIterator @@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator @@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables def flush() = { - writer.commitAndClose() - val bytesWritten = writer.bytesWritten + val w = writer + writer = null + w.commitAndClose() + val bytesWritten = w.bytesWritten batchSizes.append(bytesWritten) _diskBytesSpilled += bytesWritten objectsWritten = 0 } + var success = false try { val it = currentMap.destructiveSortedIterator(keyComparator) while (it.hasNext) { @@ -215,16 +218,28 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer.close() writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) } } if (objectsWritten > 0) { flush() + } else if (writer != null) { + val w = writer + writer = null + w.revertPartialWritesAndClose() } + success = true } finally { - // Partial failures cannot be tolerated; do not revert partial writes - writer.close() + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + if (writer != null) { + writer.revertPartialWritesAndClose() + } + if (file.exists()) { + file.delete() + } + } } currentMap = new SizeTrackingAppendOnlyMap[K, C] @@ -389,27 +404,51 @@ class ExternalAppendOnlyMap[K, V, C]( * An iterator that returns (K, C) pairs in sorted order from an on-disk map */ private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) - extends Iterator[(K, C)] { - private val fileStream = new FileInputStream(file) - private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) + extends Iterator[(K, C)] + { + private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 + assert(file.length() == batchOffsets(batchOffsets.length - 1)) + + private var batchIndex = 0 // Which batch we're in + private var fileStream: FileInputStream = null // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var batchStream = nextBatchStream() - private var compressedStream = blockManager.wrapForCompression(blockId, batchStream) - private var deserializeStream = ser.deserializeStream(compressedStream) + private var deserializeStream = nextBatchStream() private var nextItem: (K, C) = null private var objectsRead = 0 /** * Construct a stream that reads only from the next batch. */ - private def nextBatchStream(): InputStream = { - if (batchSizes.length > 0) { - ByteStreams.limit(bufferedStream, batchSizes.remove(0)) + private def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. + if (batchIndex < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchIndex) + fileStream = new FileInputStream(file) + fileStream.getChannel.position(start) + batchIndex += 1 + + val end = batchOffsets(batchIndex) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) + ser.deserializeStream(compressedStream) } else { // No more batches left - bufferedStream + cleanup() + null } } @@ -424,10 +463,8 @@ class ExternalAppendOnlyMap[K, V, C]( val item = deserializeStream.readObject().asInstanceOf[(K, C)] objectsRead += 1 if (objectsRead == serializerBatchSize) { - batchStream = nextBatchStream() - compressedStream = blockManager.wrapForCompression(blockId, batchStream) - deserializeStream = ser.deserializeStream(compressedStream) objectsRead = 0 + deserializeStream = nextBatchStream() } item } catch { @@ -439,6 +476,9 @@ class ExternalAppendOnlyMap[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { + if (deserializeStream == null) { + return false + } nextItem = readNextItem() } nextItem != null @@ -455,7 +495,11 @@ class ExternalAppendOnlyMap[K, V, C]( // TODO: Ensure this gets called even if the iterator isn't drained. private def cleanup() { - deserializeStream.close() + batchIndex = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + fileStream = null + ds.close() file.delete() } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 6e415a2bd8ce2..b04c50bd3e196 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,7 +26,7 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.BlockId /** @@ -273,13 +273,16 @@ private[spark] class ExternalSorter[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables. // The writer is closed at the end of this process, and cannot be reused. def flush() = { - writer.commitAndClose() - val bytesWritten = writer.bytesWritten + val w = writer + writer = null + w.commitAndClose() + val bytesWritten = w.bytesWritten batchSizes.append(bytesWritten) _diskBytesSpilled += bytesWritten objectsWritten = 0 } + var success = false try { val it = collection.destructiveSortedIterator(partitionKeyComparator) while (it.hasNext) { @@ -299,13 +302,23 @@ private[spark] class ExternalSorter[K, V, C]( } if (objectsWritten > 0) { flush() + } else if (writer != null) { + val w = writer + writer = null + w.revertPartialWritesAndClose() + } + success = true + } finally { + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + if (writer != null) { + writer.revertPartialWritesAndClose() + } + if (file.exists()) { + file.delete() + } } - writer.close() - } catch { - case e: Exception => - writer.close() - file.delete() - throw e } if (usingMap) { @@ -472,36 +485,58 @@ private[spark] class ExternalSorter[K, V, C]( * partitions to be requested in order. */ private[this] class SpillReader(spill: SpilledFile) { - val fileStream = new FileInputStream(spill.file) - val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) + // Serializer batch offsets; size will be batchSize.length + 1 + val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _) // Track which partition and which batch stream we're in. These will be the indices of // the next element we will read. We'll also store the last partition read so that // readNextPartition() can figure out what partition that was from. var partitionId = 0 var indexInPartition = 0L - var batchStreamsRead = 0 + var batchId = 0 var indexInBatch = 0 var lastPartitionId = 0 skipToNextPartition() - // An intermediate stream that reads from exactly one batch + + // Intermediate file and deserializer streams that read from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - var batchStream = nextBatchStream() - var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) - var deserStream = serInstance.deserializeStream(compressedStream) + var fileStream: FileInputStream = null + var deserializeStream = nextBatchStream() // Also sets fileStream + var nextItem: (K, C) = null var finished = false /** Construct a stream that only reads from the next batch */ - def nextBatchStream(): InputStream = { - if (batchStreamsRead < spill.serializerBatchSizes.length) { - batchStreamsRead += 1 - ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1)) + def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. + if (batchId < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchId) + fileStream = new FileInputStream(spill.file) + fileStream.getChannel.position(start) + batchId += 1 + + val end = batchOffsets(batchId) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream) + serInstance.deserializeStream(compressedStream) } else { - // No more batches left; give an empty stream - bufferedStream + // No more batches left + cleanup() + null } } @@ -525,19 +560,17 @@ private[spark] class ExternalSorter[K, V, C]( * If no more pairs are left, return null. */ private def readNextItem(): (K, C) = { - if (finished) { + if (finished || deserializeStream == null) { return null } - val k = deserStream.readObject().asInstanceOf[K] - val c = deserStream.readObject().asInstanceOf[C] + val k = deserializeStream.readObject().asInstanceOf[K] + val c = deserializeStream.readObject().asInstanceOf[C] lastPartitionId = partitionId // Start reading the next batch if we're done with this one indexInBatch += 1 if (indexInBatch == serializerBatchSize) { - batchStream = nextBatchStream() - compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) - deserStream = serInstance.deserializeStream(compressedStream) indexInBatch = 0 + deserializeStream = nextBatchStream() } // Update the partition location of the element we're reading indexInPartition += 1 @@ -545,7 +578,9 @@ private[spark] class ExternalSorter[K, V, C]( // If we've finished reading the last partition, remember that we're done if (partitionId == numPartitions) { finished = true - deserStream.close() + if (deserializeStream != null) { + deserializeStream.close() + } } (k, c) } @@ -578,6 +613,17 @@ private[spark] class ExternalSorter[K, V, C]( item } } + + // Clean up our open streams and put us in a state where we can't read any more data + def cleanup() { + batchId = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + fileStream = null + ds.close() + // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop(). + // This should also be fixed in ExternalAppendOnlyMap. + } } /** diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 7de5df6e1c8bd..04d7338488628 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -30,8 +30,19 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { private def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i private def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + private def createSparkConf(loadDefaults: Boolean): SparkConf = { + val conf = new SparkConf(loadDefaults) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + // Ensure that we actually have multiple batches per spill file + conf.set("spark.shuffle.spill.batchSize", "10") + conf + } + test("simple insert") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -57,7 +68,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("insert with collision") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -80,7 +91,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("ordering") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -125,7 +136,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("null keys and values") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -166,7 +177,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("simple aggregator") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) // reduceByKey @@ -181,7 +192,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("simple cogroup") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) @@ -199,7 +210,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -249,7 +260,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -304,7 +315,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with many hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -329,7 +340,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -347,7 +358,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with null keys and values") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 65a71e5a83698..57dcb4ffabac1 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -25,6 +25,17 @@ import org.apache.spark._ import org.apache.spark.SparkContext._ class ExternalSorterSuite extends FunSuite with LocalSparkContext { + private def createSparkConf(loadDefaults: Boolean): SparkConf = { + val conf = new SparkConf(loadDefaults) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + // Ensure that we actually have multiple batches per spill file + conf.set("spark.shuffle.spill.batchSize", "10") + conf + } + test("empty data stream") { val conf = new SparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -60,7 +71,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("few elements per partition") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -102,7 +113,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("empty partitions with spilling") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -127,7 +138,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling in local cluster") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -198,7 +209,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling in local cluster with many reduce tasks") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[2,1,512]", "test", conf) @@ -269,7 +280,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in sorter") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -290,7 +301,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in sorter if there are errors") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -311,7 +322,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in shuffle") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -326,7 +337,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in shuffle with errors") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -348,7 +359,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("no partial aggregation or sorting") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -363,7 +374,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("partial aggregation without spill") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -379,7 +390,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("partial aggregation with spill, no ordering") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -395,7 +406,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("partial aggregation with spill, with ordering") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -412,7 +423,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("sorting without aggregation, no spill") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -429,7 +440,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("sorting without aggregation, with spill") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -446,7 +457,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -503,7 +514,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with many hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -526,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -547,7 +558,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with null keys and values") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) diff --git a/docs/configuration.md b/docs/configuration.md index 2a71d7b820e5f..870343f1c0bd2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -385,7 +385,7 @@ Apart from these, the following properties are also available, and may be useful When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches objects to prevent writing redundant data, however that stops garbage collection of those objects. By calling 'reset' you flush that info from the serializer, and allow old - objects to be collected. To turn off this periodic reset set it to a value <= 0. + objects to be collected. To turn off this periodic reset set it to -1. By default it will reset the serializer every 100 objects. From 2225d18a751b7a4470a93f3d9edebe0d33df75c8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Aug 2014 15:54:52 -0700 Subject: [PATCH 027/231] [SPARK-1687] [PySpark] fix unit tests related to pickable namedtuple serializer is imported multiple times during doctests, so it's better to make _hijack_namedtuple() safe to be called multiple times. Author: Davies Liu Closes #1771 from davies/fix and squashes the following commits: 1a9e336 [Davies Liu] fix unit tests (cherry picked from commit 9fd82dbbcb8b10debbe95f1acab53ae8b340f38e) Signed-off-by: Josh Rosen --- python/pyspark/serializers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 1b52c144df087..a10f85b55ad30 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -297,8 +297,11 @@ def __reduce__(self): def _hijack_namedtuple(): """ Hack namedtuple() to make it picklable """ - global _old_namedtuple # or it will put in closure + # hijack only one time + if hasattr(collections.namedtuple, "__hijack"): + return + global _old_namedtuple # or it will put in closure def _copy_func(f): return types.FunctionType(f.func_code, f.func_globals, f.func_name, f.func_defaults, f.func_closure) @@ -313,6 +316,7 @@ def namedtuple(name, fields, verbose=False, rename=False): collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.func_code = namedtuple.func_code + collections.namedtuple.__hijack = 1 # hack the cls already generated by namedtuple # those created in other module can be pickled as normal, From 4ed7b5a2ff08eccf23d90990a4d7a2663efaf204 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Aug 2014 20:39:18 -0700 Subject: [PATCH 028/231] [SPARK-2323] Exception in accumulator update should not crash DAGScheduler & SparkContext Author: Reynold Xin Closes #1772 from rxin/accumulator-dagscheduler and squashes the following commits: 6a58520 [Reynold Xin] [SPARK-2323] Exception in accumulator update should not crash DAGScheduler & SparkContext. (cherry picked from commit 05bf4e4aff0d052a53d3e64c43688f07e27fec50) Signed-off-by: Reynold Xin --- .../org/apache/spark/scheduler/DAGScheduler.scala | 9 +++++++-- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 11 +++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d87c3048985fc..9fa3a4e9c71ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -904,8 +904,13 @@ class DAGScheduler( event.reason match { case Success => if (event.accumUpdates != null) { - // TODO: fail the stage if the accumulator update fails... - Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted + try { + Accumulators.add(event.accumUpdates) + } catch { + // If we see an exception during accumulator update, just log the error and move on. + case e: Exception => + logError(s"Failed to update accumulators for $task", e) + } } stage.pendingTasks -= task task match { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 36e238b4c9434..8c1b0fed11f72 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -622,8 +622,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } - // TODO: Fix this and un-ignore the test. - ignore("misbehaved accumulator should not crash DAGScheduler and SparkContext") { + test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { val acc = new Accumulator[Int](0, new AccumulatorParam[Int] { override def addAccumulator(t1: Int, t2: Int): Int = t1 + t2 override def zero(initialValue: Int): Int = 0 @@ -633,14 +632,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F }) // Run this on executors - intercept[SparkDriverExecutionException] { - sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } - } + sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } // Run this within a local thread - intercept[SparkDriverExecutionException] { - sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) - } + sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) // Make sure we can still run local commands as well as cluster commands. assert(sc.parallelize(1 to 10, 2).count() === 10) From a0922854909176a24cc689a7e8595303dcf62f3f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 23:27:53 -0700 Subject: [PATCH 029/231] SPARK-2685. Update ExternalAppendOnlyMap to avoid buffer.remove() Replaces this with an O(1) operation that does not have to shift over the whole tail of the array into the gap produced by the element removed. Author: Matei Zaharia Closes #1773 from mateiz/SPARK-2685 and squashes the following commits: 1ea028a [Matei Zaharia] Update comments in StreamBuffer and EAOM, and reuse ArrayBuffers eb1abfd [Matei Zaharia] Update ExternalAppendOnlyMap to avoid buffer.remove() (cherry picked from commit 066765d60d21b6b9943862b788e4a4bd07396e6c) Signed-off-by: Matei Zaharia --- .../collection/ExternalAppendOnlyMap.scala | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5d10a1f84493c..1f7d2dc838ebc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -286,30 +286,32 @@ class ExternalAppendOnlyMap[K, V, C]( private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => - val kcPairs = getMorePairs(it) + val kcPairs = new ArrayBuffer[(K, C)] + readNextHashCode(it, kcPairs) if (kcPairs.length > 0) { mergeHeap.enqueue(new StreamBuffer(it, kcPairs)) } } /** - * Fetch from the given iterator until a key of different hash is retrieved. + * Fill a buffer with the next set of keys with the same hash code from a given iterator. We + * read streams one hash code at a time to ensure we don't miss elements when they are merged. + * + * Assumes the given iterator is in sorted order of hash code. * - * In the event of key hash collisions, this ensures no pairs are hidden from being merged. - * Assume the given iterator is in sorted order. + * @param it iterator to read from + * @param buf buffer to write the results into */ - private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = { - val kcPairs = new ArrayBuffer[(K, C)] + private def readNextHashCode(it: BufferedIterator[(K, C)], buf: ArrayBuffer[(K, C)]): Unit = { if (it.hasNext) { var kc = it.next() - kcPairs += kc + buf += kc val minHash = hashKey(kc) while (it.hasNext && it.head._1.hashCode() == minHash) { kc = it.next() - kcPairs += kc + buf += kc } } - kcPairs } /** @@ -321,7 +323,9 @@ class ExternalAppendOnlyMap[K, V, C]( while (i < buffer.pairs.length) { val pair = buffer.pairs(i) if (pair._1 == key) { - buffer.pairs.remove(i) + // Note that there's at most one pair in the buffer with a given key, since we always + // merge stuff in a map before spilling, so it's safe to return after the first we find + removeFromBuffer(buffer.pairs, i) return mergeCombiners(baseCombiner, pair._2) } i += 1 @@ -329,6 +333,19 @@ class ExternalAppendOnlyMap[K, V, C]( baseCombiner } + /** + * Remove the index'th element from an ArrayBuffer in constant time, swapping another element + * into its place. This is more efficient than the ArrayBuffer.remove method because it does + * not have to shift all the elements in the array over. It works for our array buffers because + * we don't care about the order of elements inside, we just want to search them for a key. + */ + private def removeFromBuffer[T](buffer: ArrayBuffer[T], index: Int): T = { + val elem = buffer(index) + buffer(index) = buffer(buffer.size - 1) // This also works if index == buffer.size - 1 + buffer.reduceToSize(buffer.size - 1) + elem + } + /** * Return true if there exists an input stream that still has unvisited pairs. */ @@ -346,7 +363,7 @@ class ExternalAppendOnlyMap[K, V, C]( val minBuffer = mergeHeap.dequeue() val minPairs = minBuffer.pairs val minHash = minBuffer.minKeyHash - val minPair = minPairs.remove(0) + val minPair = removeFromBuffer(minPairs, 0) val minKey = minPair._1 var minCombiner = minPair._2 assert(hashKey(minPair) == minHash) @@ -363,7 +380,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Repopulate each visited stream buffer and add it back to the queue if it is non-empty mergedBuffers.foreach { buffer => if (buffer.isEmpty) { - buffer.pairs ++= getMorePairs(buffer.iterator) + readNextHashCode(buffer.iterator, buffer.pairs) } if (!buffer.isEmpty) { mergeHeap.enqueue(buffer) @@ -375,10 +392,13 @@ class ExternalAppendOnlyMap[K, V, C]( /** * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. - * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to - * hash collisions, it is possible for multiple keys to be "tied" for being the lowest. + * Each buffer maintains all of the key-value pairs with what is currently the lowest hash + * code among keys in the stream. There may be multiple keys if there are hash collisions. + * Note that because when we spill data out, we only spill one value for each key, there is + * at most one element for each key. * - * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. + * StreamBuffers are ordered by the minimum key hash currently available in their stream so + * that we can put them into a heap and sort that. */ private class StreamBuffer( val iterator: BufferedIterator[(K, C)], From d13d253fea6dd1f666c4c94087173f734843f2b5 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 23:41:03 -0700 Subject: [PATCH 030/231] SPARK-2711. Create a ShuffleMemoryManager to track memory for all spilling collections This tracks memory properly if there are multiple spilling collections in the same task (which was a problem before), and also implements an algorithm that lets each thread grow up to 1 / 2N of the memory pool (where N is the number of threads) before spilling, which avoids an inefficiency with small spills we had before (some threads would spill many times at 0-1 MB because the pool was allocated elsewhere). Author: Matei Zaharia Closes #1707 from mateiz/spark-2711 and squashes the following commits: debf75b [Matei Zaharia] Review comments 24f28f3 [Matei Zaharia] Small rename c8f3a8b [Matei Zaharia] Update ShuffleMemoryManager to be able to partially grant requests 315e3a5 [Matei Zaharia] Some review comments b810120 [Matei Zaharia] Create central manager to track memory for all spilling collections (cherry picked from commit 4fde28c2063f673ec7f51d514ba62a73321960a1) Signed-off-by: Matei Zaharia --- .../scala/org/apache/spark/SparkEnv.scala | 10 +- .../org/apache/spark/executor/Executor.scala | 5 +- .../spark/shuffle/ShuffleMemoryManager.scala | 125 ++++++++ .../collection/ExternalAppendOnlyMap.scala | 48 +-- .../util/collection/ExternalSorter.scala | 49 +-- .../shuffle/ShuffleMemoryManagerSuite.scala | 294 ++++++++++++++++++ 6 files changed, 450 insertions(+), 81 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0bce531aaba3e..dd8e4ac66dc66 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -35,7 +35,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -66,12 +66,9 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, + val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { - // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations - // All accesses should be manually synchronized - val shuffleMemoryMap = mutable.HashMap[Long, Long]() - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation @@ -252,6 +249,8 @@ object SparkEnv extends Logging { val shuffleManager = instantiateClass[ShuffleManager]( "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -273,6 +272,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, + shuffleMemoryManager, conf) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1bb1b4aae91bb..c2b9c660ddaec 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -276,10 +276,7 @@ private[spark] class Executor( } } finally { // Release memory used by this thread for shuffles - val shuffleMemoryMap = env.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap.remove(Thread.currentThread().getId) - } + env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala new file mode 100644 index 0000000000000..ee91a368b76ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkException, SparkConf} + +/** + * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory + * from this pool and release it as it spills data out. When a task ends, all its memory will be + * released by the Executor. + * + * This class tries to ensure that each thread gets a reasonable share of memory, instead of some + * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * this set changes. This is all done by synchronizing access on "this" to mutate state and using + * wait() and notifyAll() to signal changes. + */ +private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { + private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + + def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + + /** + * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * obtained, or 0 if none can be allocated. This call may block until there is enough free memory + * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active threads) before it is forced to spill. This can + * happen if the number of threads increases but an older thread had a lot of memory already. + */ + def tryToAcquire(numBytes: Long): Long = synchronized { + val threadId = Thread.currentThread().getId + assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) + + // Add this thread to the threadMemory map just so we can keep an accurate count of the number + // of active threads, to let other threads ramp down their memory in calls to tryToAcquire + if (!threadMemory.contains(threadId)) { + threadMemory(threadId) = 0L + notifyAll() // Will later cause waiting threads to wake up and check numThreads again + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // thread would have more than 1 / numActiveThreads of the memory) or we have enough free + // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + while (true) { + val numActiveThreads = threadMemory.keys.size + val curMem = threadMemory(threadId) + val freeMemory = maxMemory - threadMemory.values.sum + + // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads + val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem) + + if (curMem < maxMemory / (2 * numActiveThreads)) { + // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; + // if we can't give it this much now, wait for other threads to free up memory + // (this happens if older threads allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } else { + logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + wait() + } + } else { + // Only give it as much memory as is free, which might be none if it reached 1 / numThreads + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** Release numBytes bytes for the current thread. */ + def release(numBytes: Long): Unit = synchronized { + val threadId = Thread.currentThread().getId + val curMem = threadMemory.getOrElse(threadId, 0L) + if (curMem < numBytes) { + throw new SparkException( + s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + } + threadMemory(threadId) -= numBytes + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } + + /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisThread(): Unit = synchronized { + val threadId = Thread.currentThread().getId + threadMemory.remove(threadId) + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } +} + +private object ShuffleMemoryManager { + /** + * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction + * of the memory pool and a safety factor since collections can sometimes grow bigger than + * the size we target before we estimate their sizes again. + */ + def getMaxMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1f7d2dc838ebc..cc0423856cefb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Number of pairs inserted since last spill; note that we count them even if a value is merged // with a previous key in case we're doing something like groupBy where the result grows @@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && currentMap.estimateSize() >= myMemoryThreshold) { - val currentSize = currentMap.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // this map to grow and, if possible, allocate the required amount - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not synchronize spills - if (shouldSpill) { - spill(currentSize) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = currentMap.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory) // Will also release memory back to ShuffleMemoryManager } } currentMap.changeValue(curEntry._1, update) @@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C]( currentMap = new SizeTrackingAppendOnlyMap[K, C] spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } - myMemoryThreshold = 0 + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0L elementsRead = 0 _memoryBytesSpilled += mapSize diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b04c50bd3e196..101c83b264f63 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -78,6 +78,7 @@ private[spark] class ExternalSorter[K, V, C]( private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() @@ -116,13 +117,6 @@ private[spark] class ExternalSorter[K, V, C]( private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } - // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L @@ -218,31 +212,15 @@ private[spark] class ExternalSorter[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && collection.estimateSize() >= myMemoryThreshold) { - // TODO: This logic doesn't work if there are two external collections being used in the same - // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711] - - val currentSize = collection.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // us to double our threshold - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyClaimedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not hold lock during spills - if (shouldSpill) { - spill(currentSize, usingMap) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = collection.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager } } } @@ -327,11 +305,8 @@ private[spark] class ExternalSorter[K, V, C]( buffer = new SizeTrackingPairBuffer[(Int, K), C] } - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) myMemoryThreshold = 0 spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala new file mode 100644 index 0000000000000..d31bc22ee74f7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.CountDownLatch + +class ShuffleMemoryManagerSuite extends FunSuite with Timeouts { + /** Launch a thread with the given body block and return it. */ + private def startThread(name: String)(body: => Unit): Thread = { + val thread = new Thread("ShuffleMemorySuite " + name) { + override def run() { + body + } + } + thread.start() + thread + } + + test("single thread requesting memory") { + val manager = new ShuffleMemoryManager(1000L) + + assert(manager.tryToAcquire(100L) === 100L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(200L) === 100L) + assert(manager.tryToAcquire(100L) === 0L) + assert(manager.tryToAcquire(100L) === 0L) + + manager.release(500L) + assert(manager.tryToAcquire(300L) === 300L) + assert(manager.tryToAcquire(300L) === 200L) + + manager.releaseMemoryForThisThread() + assert(manager.tryToAcquire(1000L) === 1000L) + assert(manager.tryToAcquire(100L) === 0L) + } + + test("two threads requesting full memory") { + // Two threads request 500 bytes first, wait for each other to get it, and then request + // 500 more; we should immediately return 0 as both are now at 1 / N + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 500L) + assert(state.t2Result1 === 500L) + assert(state.t1Result2 === 0L) + assert(state.t2Result2 === 0L) + } + + + test("threads cannot grow past 1 / N") { + // Two threads request 250 bytes first, wait for each other to get it, and then request + // 500 more; we should only grant 250 bytes to each of them on this second request + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 250L) + assert(state.t2Result1 === 250L) + assert(state.t1Result2 === 250L) + assert(state.t2Result2 === 250L) + } + + test("threads can block to get at least 1 / 2N memory") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases 250 bytes, which should then be greanted to t2. Further requests + // by t2 will return false right away because it now has 1 / 2N of the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result = -1L + var t2Result2 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.release(250L) + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val result = manager.tryToAcquire(250L) + val endTime = System.currentTimeMillis() + state.synchronized { + state.t2Result = result + // A second call should return 0 because we're now already at 1 / 2N + state.t2Result2 = manager.tryToAcquire(100L) + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released 250 + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result === 250L, "t2 could not allocate memory") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + assert(state.t2Result2 === 0L, "t1 got extra memory the second time") + } + } + + test("releaseMemoryForThisThread") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases all its memory. t2 should now be able to grab all the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result1 = -1L + var t2Result2 = -1L + var t2Result3 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.releaseMemoryForThisThread() + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val r1 = manager.tryToAcquire(500L) + val endTime = System.currentTimeMillis() + val r2 = manager.tryToAcquire(500L) + val r3 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.t2Result2 = r2 + state.t2Result3 = r3 + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released all of it + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time") + assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time") + assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + } + } +} From 12f99cf5f88faf94d9dbfe85cb72d0010a3a25ac Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 5 Aug 2014 00:39:07 -0700 Subject: [PATCH 031/231] [SPARK-2857] Correct properties to set Master / Worker ports `master.ui.port` and `worker.ui.port` were never picked up by SparkConf, simply because they are not prefixed with "spark." Unfortunately, this is also currently the documented way of setting these values. Author: Andrew Or Closes #1779 from andrewor14/master-worker-port and squashes the following commits: 8475e95 [Andrew Or] Update docs to reflect changes in configs 4db3d5d [Andrew Or] Stop using configs that don't actually work (cherry picked from commit a646a365e3beb8d0cd7e492e625ce68ee9439a07) Signed-off-by: Patrick Wendell --- .../org/apache/spark/deploy/master/MasterArguments.scala | 4 ++-- .../scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 2 +- docs/spark-standalone.md | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index a87781fb93850..4b0dbbe543d3f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -38,8 +38,8 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } - if (conf.contains("master.ui.port")) { - webUiPort = conf.get("master.ui.port").toInt + if (conf.contains("spark.master.ui.port")) { + webUiPort = conf.get("spark.master.ui.port").toInt } parse(args.toList) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 0ad2edba2227f..a9f531e9e4cae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -58,6 +58,6 @@ private[spark] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { - requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT)) + requestedPort.getOrElse(conf.getInt("spark.worker.ui.port", WorkerWebUI.DEFAULT_PORT)) } } diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 2fb30765f35e8..293a7ac9bc9aa 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -314,7 +314,7 @@ configure those ports. Standalone Cluster Master 8080 Web UI - master.ui.port + spark.master.ui.port Jetty-based @@ -338,7 +338,7 @@ configure those ports. Worker 8081 Web UI - worker.ui.port + spark.worker.ui.port Jetty-based From 075ba67819b0f250cc176c96f2f5d8eddb0b16ac Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 5 Aug 2014 00:51:07 -0700 Subject: [PATCH 032/231] [SPARK-1779] Throw an exception if memory fractions are not between 0 and 1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: wangfei Author: wangfei Closes #714 from scwf/memoryFraction and squashes the following commits: 6e385b9 [wangfei] Update SparkConf.scala da6ee59 [wangfei] add configs 829a195 [wangfei] add indent 717c0ca [wangfei] updated to make more concise fc45476 [wangfei] validate memoryfraction in sparkconf 2e79b3d [wangfei] && => || 43621bd [wangfei] && => || cf38bcf [wangfei] throw IllegalArgumentException 14d18ac [wangfei] throw IllegalArgumentException dff1f0f [wangfei] Update BlockManager.scala 764965f [wangfei] Update ExternalAppendOnlyMap.scala a59d76b [wangfei] Throw exception when memoryFracton is out of range 7b899c2 [wangfei] 【SPARK-1779】 --- .../main/scala/org/apache/spark/SparkConf.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 38700847c80f4..cce7a23d3b9fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -238,6 +238,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } + // Validate memory fractions + val memoryKeys = Seq( + "spark.storage.memoryFraction", + "spark.shuffle.memoryFraction", + "spark.shuffle.safetyFraction", + "spark.storage.unrollFraction", + "spark.storage.safetyFraction") + for (key <- memoryKeys) { + val value = getDouble(key, 0.5) + if (value > 1 || value < 0) { + throw new IllegalArgumentException("$key should be between 0 and 1 (was '$value').") + } + } + // Check for legacy configs sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = From b92a450583989470ff53b62c124d908ad661e29a Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 5 Aug 2014 10:40:28 -0700 Subject: [PATCH 033/231] [SPARK-1022][Streaming] Add Kafka real unit test This PR is a updated version of (https://github.com/apache/spark/pull/557) to actually test sending and receiving data through Kafka, and fix previous flaky issues. @tdas, would you mind reviewing this PR? Thanks a lot. Author: jerryshao Closes #1751 from jerryshao/kafka-unit-test and squashes the following commits: b6a505f [jerryshao] code refactor according to comments 5222330 [jerryshao] Change JavaKafkaStreamSuite to better test it 5525f10 [jerryshao] Fix flaky issue of Kafka real unit test 4559310 [jerryshao] Minor changes for Kafka unit test 860f649 [jerryshao] Minor style changes, and tests ignored due to flakiness 796d4ca [jerryshao] Add real Kafka streaming test --- external/kafka/pom.xml | 6 + .../streaming/kafka/JavaKafkaStreamSuite.java | 125 +++++++++-- .../streaming/kafka/KafkaStreamSuite.scala | 197 ++++++++++++++++-- 3 files changed, 293 insertions(+), 35 deletions(-) diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index daf03360bc5f5..2aee99949223a 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -70,6 +70,12 @@ + + net.sf.jopt-simple + jopt-simple + 3.2 + test + org.scalatest scalatest_${scala.binary.version} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 9f8046bf00f8f..0571454c01dae 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -17,31 +17,118 @@ package org.apache.spark.streaming.kafka; +import java.io.Serializable; import java.util.HashMap; +import java.util.List; + +import scala.Predef; +import scala.Tuple2; +import scala.collection.JavaConverters; + +import junit.framework.Assert; -import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream; -import org.junit.Test; -import com.google.common.collect.Maps; import kafka.serializer.StringDecoder; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { + private transient KafkaStreamSuite testSuite = new KafkaStreamSuite(); + + @Before + @Override + public void setUp() { + testSuite.beforeFunction(); + System.clearProperty("spark.driver.port"); + //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + } + + @After + @Override + public void tearDown() { + ssc.stop(); + ssc = null; + System.clearProperty("spark.driver.port"); + testSuite.afterFunction(); + } -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext { @Test - public void testKafkaStream() { - HashMap topics = Maps.newHashMap(); - - // tests the API, does not actually test data receiving - JavaPairReceiverInputDStream test1 = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics); - JavaPairReceiverInputDStream test2 = KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, - StorageLevel.MEMORY_AND_DISK_SER_2()); - - HashMap kafkaParams = Maps.newHashMap(); - kafkaParams.put("zookeeper.connect", "localhost:12345"); - kafkaParams.put("group.id","consumer-group"); - JavaPairReceiverInputDStream test3 = KafkaUtils.createStream(ssc, - String.class, String.class, StringDecoder.class, StringDecoder.class, - kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2()); + public void testKafkaStream() throws InterruptedException { + String topic = "topic1"; + HashMap topics = new HashMap(); + topics.put(topic, 1); + + HashMap sent = new HashMap(); + sent.put("a", 5); + sent.put("b", 3); + sent.put("c", 10); + + testSuite.createTopic(topic); + HashMap tmp = new HashMap(sent); + testSuite.produceAndSendMessage(topic, + JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("zookeeper.connect", testSuite.zkConnect()); + kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("auto.offset.reset", "smallest"); + + JavaPairDStream stream = KafkaUtils.createStream(ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topics, + StorageLevel.MEMORY_ONLY_SER()); + + final HashMap result = new HashMap(); + + JavaDStream words = stream.map( + new Function, String>() { + @Override + public String call(Tuple2 tuple2) throws Exception { + return tuple2._2(); + } + } + ); + + words.countByValue().foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws Exception { + List> ret = rdd.collect(); + for (Tuple2 r : ret) { + if (result.containsKey(r._1())) { + result.put(r._1(), result.get(r._1()) + r._2()); + } else { + result.put(r._1(), r._2()); + } + } + + return null; + } + } + ); + + ssc.start(); + ssc.awaitTermination(3000); + + Assert.assertEquals(sent.size(), result.size()); + for (String k : sent.keySet()) { + Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index e6f2c4a5cf5d1..c0b55e9340253 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,28 +17,193 @@ package org.apache.spark.streaming.kafka -import kafka.serializer.StringDecoder +import java.io.File +import java.net.InetSocketAddress +import java.util.{Properties, Random} + +import scala.collection.mutable + +import kafka.admin.CreateTopicCommand +import kafka.common.TopicAndPartition +import kafka.producer.{KeyedMessage, ProducerConfig, Producer} +import kafka.utils.ZKStringSerializer +import kafka.serializer.{StringDecoder, StringEncoder} +import kafka.server.{KafkaConfig, KafkaServer} + +import org.I0Itec.zkclient.ZkClient + +import org.apache.zookeeper.server.ZooKeeperServer +import org.apache.zookeeper.server.NIOServerCnxnFactory + import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.util.Utils class KafkaStreamSuite extends TestSuiteBase { + import KafkaTestUtils._ + + val zkConnect = "localhost:2181" + val zkConnectionTimeout = 6000 + val zkSessionTimeout = 6000 + + val brokerPort = 9092 + val brokerProps = getBrokerConfig(brokerPort, zkConnect) + val brokerConf = new KafkaConfig(brokerProps) + + protected var zookeeper: EmbeddedZookeeper = _ + protected var zkClient: ZkClient = _ + protected var server: KafkaServer = _ + protected var producer: Producer[String, String] = _ + + override def useManualClock = false + + override def beforeFunction() { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(zkConnect) + logInfo("==================== 0 ====================") + zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + logInfo("==================== 1 ====================") - test("kafka input stream") { + // Kafka broker startup + server = new KafkaServer(brokerConf) + logInfo("==================== 2 ====================") + server.startup() + logInfo("==================== 3 ====================") + Thread.sleep(2000) + logInfo("==================== 4 ====================") + super.beforeFunction() + } + + override def afterFunction() { + producer.close() + server.shutdown() + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + zkClient.close() + zookeeper.shutdown() + + super.afterFunction() + } + + test("Kafka input stream") { val ssc = new StreamingContext(master, framework, batchDuration) - val topics = Map("my-topic" -> 1) - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:1234", "group", topics) - val test2: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK_SER_2) - val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group") - val test3: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2) - - // TODO: Actually test receiving data + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkConnect, + "group.id" -> s"test-consumer-${random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, + kafkaParams, + Map(topic -> 1), + StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v } + .countByValue() + .foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + } + ssc.start() + ssc.awaitTermination(3000) + + assert(sent.size === result.size) + sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + ssc.stop() } + + private def createTestMessage(topic: String, sent: Map[String, Int]) + : Seq[KeyedMessage[String, String]] = { + val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { + new KeyedMessage[String, String](topic, s) + } + messages.toSeq + } + + def createTopic(topic: String) { + CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") + logInfo("==================== 5 ====================") + // wait until metadata is propagated + waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + } + + def produceAndSendMessage(topic: String, sent: Map[String, Int]) { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port + producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer.send(createTestMessage(topic, sent): _*) + logInfo("==================== 6 ====================") + } +} + +object KafkaTestUtils { + val random = new Random() + + def getBrokerConfig(port: Int, zkConnect: String): Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", port.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkConnect) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + def getProducerConfig(brokerList: String): Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerList) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { + val startTime = System.currentTimeMillis() + while (true) { + if (condition()) + return true + if (System.currentTimeMillis() > startTime + waitTime) + return false + Thread.sleep(waitTime.min(100L)) + } + // Should never go to here + throw new RuntimeException("unexpected error") + } + + def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, + timeout: Long) { + assert(waitUntilTrue(() => + servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( + TopicAndPartition(topic, partition))), timeout), + s"Partition [$topic, $partition] metadata not propagated after timeout") + } + + class EmbeddedZookeeper(val zkConnect: String) { + val random = new Random() + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } } From 6c0c65fc85677ab2cae819a546ea50ed660994c3 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 5 Aug 2014 12:48:26 -0500 Subject: [PATCH 034/231] SPARK-1528 - spark on yarn, add support for accessing remote HDFS Add a config (spark.yarn.access.namenodes) to allow applications running on yarn to access other secure HDFS cluster. User just specifies the namenodes of the other clusters and we get Tokens for those and ship them with the spark application. Author: Thomas Graves Closes #1159 from tgravescs/spark-1528 and squashes the following commits: ddbcd16 [Thomas Graves] review comments 0ac8501 [Thomas Graves] SPARK-1528 - add support for accessing remote HDFS (cherry picked from commit 2c0f705e26ca3dfc43a1e9a0722c0e57f67c970a) Signed-off-by: Thomas Graves --- docs/running-on-yarn.md | 7 +++ .../apache/spark/deploy/yarn/ClientBase.scala | 56 +++++++++++++------ .../spark/deploy/yarn/ClientBaseSuite.scala | 55 +++++++++++++++++- 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 0362f5a223319..573930dbf4e54 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -106,6 +106,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes set this configuration to "hdfs:///some/path". + + spark.yarn.access.namenodes + (none) + + A list of secure HDFS namenodes your Spark application is going to access. For example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. The Spark application must have acess to the namenodes listed and Kerberos must be properly configured to be able to access them (either in the same realm or in a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. + + # Launching Spark on YARN diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index b7e8636e02eb2..ed8f56ab8b75e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -191,23 +191,11 @@ trait ClientBase extends Logging { // Upload Spark and the application JAR to the remote file system if necessary. Add them as // local resources to the application master. val fs = FileSystem.get(conf) - - val delegTokenRenewer = Master.getMasterPrincipal(conf) - if (UserGroupInformation.isSecurityEnabled()) { - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - } val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort - - if (UserGroupInformation.isSecurityEnabled()) { - val dstFs = dst.getFileSystem(conf) - dstFs.addDelegationTokens(delegTokenRenewer, credentials) - } + val nns = ClientBase.getNameNodesToAccess(sparkConf) + dst + ClientBase.obtainTokensForNamenodes(nns, conf, credentials) + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) @@ -614,4 +602,40 @@ object ClientBase extends Logging { YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, File.pathSeparator) + /** + * Get the list of namenodes the user may access. + */ + private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "").split(",").map(_.trim()).filter(!_.isEmpty) + .map(new Path(_)).toSet + } + + private[yarn] def getTokenRenewer(conf: Configuration): String = { + val delegTokenRenewer = Master.getMasterPrincipal(conf) + logDebug("delegation token renewer is: " + delegTokenRenewer) + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) + } + delegTokenRenewer + } + + /** + * Obtains tokens for the namenodes passed in and adds them to the credentials. + */ + private[yarn] def obtainTokensForNamenodes(paths: Set[Path], conf: Configuration, + creds: Credentials) { + if (UserGroupInformation.isSecurityEnabled()) { + val delegTokenRenewer = getTokenRenewer(conf) + + paths.foreach { + dst => + val dstFs = dst.getFileSystem(conf) + logDebug("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } + } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 686714dc36488..68cc2890f3a22 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -31,6 +31,8 @@ import org.apache.hadoop.yarn.api.records.ContainerLaunchContext import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ + + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -38,7 +40,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } import scala.util.Try -import org.apache.spark.SparkConf +import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils class ClientBaseSuite extends FunSuite with Matchers { @@ -138,6 +140,57 @@ class ClientBaseSuite extends FunSuite with Matchers { } } + test("check access nns empty") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns unset") { + val sparkConf = new SparkConf() + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access nns space") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access two nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) + } + + test("check token renewer") { + val hadoopConf = new Configuration() + hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") + hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") + val renewer = ClientBase.getTokenRenewer(hadoopConf) + renewer should be ("yarn/myrm:8032@SPARKTEST.COM") + } + + test("check token renewer default") { + val hadoopConf = new Configuration() + val caught = + intercept[SparkException] { + ClientBase.getTokenRenewer(hadoopConf) + } + assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From e3fe6571decfdc406ec6d505fd92f9f2b85a618c Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 5 Aug 2014 12:52:52 -0500 Subject: [PATCH 035/231] SPARK-1890 and SPARK-1891- add admin and modify acls It was easier to combine these 2 jira since they touch many of the same places. This pr adds the following: - adds modify acls - adds admin acls (list of admins/users that get added to both view and modify acls) - modify Kill button on UI to take modify acls into account - changes config name of spark.ui.acls.enable to spark.acls.enable since I choose poorly in original name. We keep backwards compatibility so people can still use spark.ui.acls.enable. The acls should apply to any web ui as well as any CLI interfaces. - send view and modify acls information on to YARN so that YARN interfaces can use (yarn cli for killing applications for example). Author: Thomas Graves Closes #1196 from tgravescs/SPARK-1890 and squashes the following commits: 8292eb1 [Thomas Graves] review comments b92ec89 [Thomas Graves] remove unneeded variable from applistener 4c765f4 [Thomas Graves] Add in admin acls 72eb0ac [Thomas Graves] Add modify acls (cherry picked from commit 1c5555a23d3aa40423d658cfbf2c956ad415a6b1) Signed-off-by: Thomas Graves --- .../org/apache/spark/SecurityManager.scala | 107 +++++++++++++++--- .../deploy/history/FsHistoryProvider.scala | 4 +- .../scheduler/ApplicationEventListener.scala | 4 +- .../apache/spark/ui/jobs/JobProgressTab.scala | 2 +- .../apache/spark/SecurityManagerSuite.scala | 83 ++++++++++++-- docs/configuration.md | 27 ++++- docs/security.md | 7 +- .../apache/spark/deploy/yarn/ClientBase.scala | 9 +- 8 files changed, 206 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 74aa441619bd2..25c2c9fc6af7c 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -41,10 +41,19 @@ import org.apache.spark.deploy.SparkHadoopUtil * secure the UI if it has data that other users should not be allowed to see. The javax * servlet filter specified by the user can authenticate the user and then once the user * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * authorized to view the UI. The configs 'spark.acls.enable' and 'spark.ui.view.acls' * control the behavior of the acls. Note that the person who started the application * always has view access to the UI. * + * Spark has a set of modify acls (`spark.modify.acls`) that controls which users have permission + * to modify a single application. This would include things like killing the application. By + * default the person who started the application has modify access. For modify access through + * the UI, you must have a filter that does authentication in place for the modify acls to work + * properly. + * + * Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators + * who always have permission to view or modify the Spark application. + * * Spark does not currently support encryption after authentication. * * At this point spark has multiple communication protocols that need to be secured and @@ -137,18 +146,32 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { private val sparkSecretLookupKey = "sparkCookie" private val authOn = sparkConf.getBoolean("spark.authenticate", false) - private var uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) + // keep spark.ui.acls.enable for backwards compatibility with 1.0 + private var aclsOn = sparkConf.getOption("spark.acls.enable").getOrElse( + sparkConf.get("spark.ui.acls.enable", "false")).toBoolean + + // admin acls should be set before view or modify acls + private var adminAcls: Set[String] = + stringToSet(sparkConf.get("spark.admin.acls", "")) private var viewAcls: Set[String] = _ + + // list of users who have permission to modify the application. This should + // apply to both UI and CLI for things like killing the application. + private var modifyAcls: Set[String] = _ + // always add the current user and SPARK_USER to the viewAcls - private val defaultAclUsers = Seq[String](System.getProperty("user.name", ""), + private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), Option(System.getenv("SPARK_USER")).getOrElse("")) + setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) + setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) private val secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + - "; ui acls " + (if (uiAclsOn) "enabled" else "disabled") + - "; users with view permissions: " + viewAcls.toString()) + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + + "; users with view permissions: " + viewAcls.toString() + + "; users with modify permissions: " + modifyAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its @@ -169,18 +192,51 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { ) } - private[spark] def setViewAcls(defaultUsers: Seq[String], allowedUsers: String) { - viewAcls = (defaultUsers ++ allowedUsers.split(',')).map(_.trim()).filter(!_.isEmpty).toSet + /** + * Split a comma separated String, filter out any empty items, and return a Set of strings + */ + private def stringToSet(list: String): Set[String] = { + list.split(',').map(_.trim).filter(!_.isEmpty).toSet + } + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setViewAcls(defaultUsers: Set[String], allowedUsers: String) { + viewAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) logInfo("Changing view acls to: " + viewAcls.mkString(",")) } - private[spark] def setViewAcls(defaultUser: String, allowedUsers: String) { - setViewAcls(Seq[String](defaultUser), allowedUsers) + def setViewAcls(defaultUser: String, allowedUsers: String) { + setViewAcls(Set[String](defaultUser), allowedUsers) + } + + def getViewAcls: String = viewAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setModifyAcls(defaultUsers: Set[String], allowedUsers: String) { + modifyAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) + logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) + } + + def getModifyAcls: String = modifyAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setAdminAcls(adminUsers: String) { + adminAcls = stringToSet(adminUsers) + logInfo("Changing admin acls to: " + adminAcls.mkString(",")) } - private[spark] def setUIAcls(aclSetting: Boolean) { - uiAclsOn = aclSetting - logInfo("Changing acls enabled to: " + uiAclsOn) + def setAcls(aclSetting: Boolean) { + aclsOn = aclSetting + logInfo("Changing acls enabled to: " + aclsOn) } /** @@ -224,22 +280,39 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * Check to see if Acls for the UI are enabled * @return true if UI authentication is enabled, otherwise false */ - def uiAclsEnabled(): Boolean = uiAclsOn + def aclsEnabled(): Boolean = aclsOn /** * Checks the given user against the view acl list to see if they have - * authorization to view the UI. If the UI acls must are disabled - * via spark.ui.acls.enable, all users have view access. + * authorization to view the UI. If the UI acls are disabled + * via spark.acls.enable, all users have view access. If the user is null + * it is assumed authentication is off and all users have access. * * @param user to see if is authorized * @return true is the user has permission, otherwise false */ def checkUIViewPermissions(user: String): Boolean = { - logDebug("user=" + user + " uiAclsEnabled=" + uiAclsEnabled() + " viewAcls=" + + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + if (aclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true } + /** + * Checks the given user against the modify acl list to see if they have + * authorization to modify the application. If the UI acls are disabled + * via spark.acls.enable, all users have modify access. If the user is null + * it is assumed authentication isn't turned on and all users have access. + * + * @param user to see if is authorized + * @return true is the user has permission, otherwise false + */ + def checkModifyPermissions(user: String): Boolean = { + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + + modifyAcls.mkString(",")) + if (aclsEnabled() && (user != null) && (!modifyAcls.contains(user))) false else true + } + + /** * Check to see if authentication for the Spark communication protocols is enabled * @return true if authentication is enabled, otherwise false diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6d2d4cef1ee46..cc06540ee0647 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -189,7 +189,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (ui != null) { val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setUIAcls(uiAclsEnabled) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls) ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) } (appInfo, ui) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index cd5d44ad4a7e6..162158babc35b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -29,7 +29,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var startTime = -1L var endTime = -1L var viewAcls = "" - var enableViewAcls = false + var adminAcls = "" def applicationStarted = startTime != -1 @@ -55,7 +55,7 @@ private[spark] class ApplicationEventListener extends SparkListener { val environmentDetails = environmentUpdate.environmentDetails val allProperties = environmentDetails("Spark Properties").toMap viewAcls = allProperties.getOrElse("spark.ui.view.acls", "") - enableViewAcls = allProperties.getOrElse("spark.ui.acls.enable", "false").toBoolean + adminAcls = allProperties.getOrElse("spark.admin.acls", "") } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index 3308c8c8a3d37..8a01ec80c9dd6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -41,7 +41,7 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) def handleKillRequest(request: HttpServletRequest) = { - if (killEnabled) { + if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e39093e24d68a..fcca0867b8072 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -31,7 +31,7 @@ class SecurityManagerSuite extends FunSuite { conf.set("spark.ui.view.acls", "user1,user2") val securityManager = new SecurityManager(conf); assert(securityManager.isAuthenticationEnabled() === true) - assert(securityManager.uiAclsEnabled() === true) + assert(securityManager.aclsEnabled() === true) assert(securityManager.checkUIViewPermissions("user1") === true) assert(securityManager.checkUIViewPermissions("user2") === true) assert(securityManager.checkUIViewPermissions("user3") === false) @@ -41,17 +41,17 @@ class SecurityManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.ui.view.acls", "user1,user2") val securityManager = new SecurityManager(conf); - securityManager.setUIAcls(true) - assert(securityManager.uiAclsEnabled() === true) - securityManager.setUIAcls(false) - assert(securityManager.uiAclsEnabled() === false) + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setAcls(false) + assert(securityManager.aclsEnabled() === false) // acls are off so doesn't matter what view acls set to assert(securityManager.checkUIViewPermissions("user4") === true) - securityManager.setUIAcls(true) - assert(securityManager.uiAclsEnabled() === true) - securityManager.setViewAcls(ArrayBuffer[String]("user5"), "user6,user7") + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setViewAcls(Set[String]("user5"), "user6,user7") assert(securityManager.checkUIViewPermissions("user1") === false) assert(securityManager.checkUIViewPermissions("user5") === true) assert(securityManager.checkUIViewPermissions("user6") === true) @@ -59,5 +59,72 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.checkUIViewPermissions("user8") === false) assert(securityManager.checkUIViewPermissions(null) === true) } + + test("set security modify acls") { + val conf = new SparkConf + conf.set("spark.modify.acls", "user1,user2") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setAcls(false) + assert(securityManager.aclsEnabled() === false) + + // acls are off so doesn't matter what view acls set to + assert(securityManager.checkModifyPermissions("user4") === true) + + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setModifyAcls(Set("user5"), "user6,user7") + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user5") === true) + assert(securityManager.checkModifyPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === false) + assert(securityManager.checkModifyPermissions(null) === true) + } + + test("set security admin acls") { + val conf = new SparkConf + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "user3") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkModifyPermissions("user2") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user3") === false) + assert(securityManager.checkModifyPermissions("user5") === false) + assert(securityManager.checkModifyPermissions(null) === true) + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user2") === true) + assert(securityManager.checkUIViewPermissions("user3") === true) + assert(securityManager.checkUIViewPermissions("user4") === false) + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions(null) === true) + + securityManager.setAdminAcls("user6") + securityManager.setViewAcls(Set[String]("user8"), "user9") + securityManager.setModifyAcls(Set("user11"), "user9") + assert(securityManager.checkModifyPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user11") === true) + assert(securityManager.checkModifyPermissions("user9") === true) + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user4") === false) + assert(securityManager.checkModifyPermissions(null) === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkUIViewPermissions("user8") === true) + assert(securityManager.checkUIViewPermissions("user9") === true) + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user3") === false) + assert(securityManager.checkUIViewPermissions(null) === true) + + } + + } diff --git a/docs/configuration.md b/docs/configuration.md index 870343f1c0bd2..13334657a2107 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -815,13 +815,13 @@ Apart from these, the following properties are also available, and may be useful - spark.ui.acls.enable + spark.acls.enable false - Whether Spark web ui acls should are enabled. If enabled, this checks to see if the user has - access permissions to view the web ui. See spark.ui.view.acls for more details. - Also note this requires the user to be known, if the user comes across as null no checks - are done. Filters can be used to authenticate and set the user. + Whether Spark acls should are enabled. If enabled, this checks to see if the user has + access permissions to view or modify the job. Note this requires the user to be known, + so if the user comes across as null no checks are done. Filters can be used with the UI + to authenticate and set the user. @@ -832,6 +832,23 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. + + spark.modify.acls + Empty + + Comma separated list of users that have modify access to the Spark job. By default only the + user that started the Spark job has access to modify it (kill it for example). + + + + spark.admin.acls + Empty + + Comma separated list of users/administrators that have view and modify access to all Spark jobs. + This can be used if you run on a shared cluster and have a set of administrators or devs who + help debug when things work. + + #### Spark Streaming diff --git a/docs/security.md b/docs/security.md index 90ba678033b19..8312f8d017e1f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -8,8 +8,11 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.ui.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. -On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. + +Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. + +Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the config `spark.admin.acls`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index ed8f56ab8b75e..44e025b8f60ba 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{SparkException, Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The @@ -405,6 +405,13 @@ trait ClientBase extends Logging { amContainer.setCommands(printableCommands) setupSecurityToken(amContainer) + + // send the acl settings into YARN to control who has access via YARN interfaces + val securityManager = new SecurityManager(sparkConf) + val acls = Map[ApplicationAccessType, String] ( + ApplicationAccessType.VIEW_APP -> securityManager.getViewAcls, + ApplicationAccessType.MODIFY_APP -> securityManager.getModifyAcls) + amContainer.setApplicationACLs(acls) amContainer } } From 388ab534b318e6736484a2fab6f88390abbf8c55 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 5 Aug 2014 11:17:50 -0700 Subject: [PATCH 036/231] [SPARK-2860][SQL] Fix coercion of CASE WHEN. Author: Michael Armbrust Closes #1785 from marmbrus/caseNull and squashes the following commits: 126006d [Michael Armbrust] better error message 2fe357f [Michael Armbrust] Fix coercion of CASE WHEN. (cherry picked from commit 6e821e3d1ae1ed23459bc7f1098510b968130152) Signed-off-by: Michael Armbrust --- .../catalyst/analysis/HiveTypeCoercion.scala | 56 +++++++++++-------- ...ll case-0-581cdfe70091e546414b202da2cebdcb | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 3 + 3 files changed, 36 insertions(+), 24 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e94f2a3bea63e..15eb5982a4a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -49,10 +49,21 @@ trait HiveTypeCoercion { BooleanCasts :: StringToIntegralCasts :: FunctionArgumentConversion :: - CastNulls :: + CaseWhenCoercion :: Division :: Nil + trait TypeWidening { + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found return the widest common type, otherwise None + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -133,16 +144,7 @@ trait HiveTypeCoercion { * - LongType to FloatType * - LongType to DoubleType */ - object WidenTypes extends Rule[LogicalPlan] { - - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) - } + object WidenTypes extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -336,28 +338,34 @@ trait HiveTypeCoercion { } /** - * Ensures that NullType gets casted to some other types under certain circumstances. + * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CastNulls extends Rule[LogicalPlan] { + object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw @ CaseWhen(branches) => + case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => val valueTypes = branches.sliding(2, 2).map { - case Seq(_, value) if value.resolved => Some(value.dataType) - case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) - case _ => None + case Seq(_, value) => value.dataType + case Seq(elseVal) => elseVal.dataType }.toSeq - if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { - val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + + logDebug(s"Input values for null casting ${valueTypes.mkString(",")}") + + if (valueTypes.distinct.size > 1) { + val commonType = valueTypes.reduce { (v1, v2) => + findTightestCommonType(v1, v2) + .getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } val transformedBranches = branches.sliding(2, 2).map { - case Seq(cond, value) if value.resolved && value.dataType == NullType => - Seq(cond, Cast(value, otherType)) - case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => - Seq(Cast(elseVal, otherType)) + case Seq(cond, value) if value.dataType != commonType => + Seq(cond, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) case s => s }.reduce(_ ++ _) CaseWhen(transformedBranches) } else { - // It is possible to have more types due to the possibility of short-circuiting. + // Types match up. Hopefully some other rule fixes whatever is wrong with resolution. cw } } diff --git a/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index aa810a291231a..2f0be49b6a6d7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -32,6 +32,9 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("null case", + "SELECT case when(true) then 1 else null end FROM src LIMIT 1") + createQueryTest("single case", """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") From 0f541abe74653ffe84381c05105a1a2f92b02da4 Mon Sep 17 00:00:00 2001 From: "Guancheng (G.C.) Chen" Date: Tue, 5 Aug 2014 11:50:08 -0700 Subject: [PATCH 037/231] [SPARK-2859] Update url of Kryo project in related docs JIRA Issue: https://issues.apache.org/jira/browse/SPARK-2859 Kryo project has been migrated from googlecode to github, hence we need to update its URL in related docs such as tuning.md. Author: Guancheng (G.C.) Chen Closes #1782 from gchen/kryo-docs and squashes the following commits: b62543c [Guancheng (G.C.) Chen] update url of Kryo project (cherry picked from commit ac3440f4f3c4b79070ffec7db0b08ad062b4df90) Signed-off-by: Patrick Wendell --- docs/tuning.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tuning.md b/docs/tuning.md index 4917c11bc1147..8fb2a0433b1a8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -32,7 +32,7 @@ in your operations) and performance. It provides two serialization libraries: [`java.io.Externalizable`](http://docs.oracle.com/javase/6/docs/api/java/io/Externalizable.html). Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. -* [Kryo serialization](http://code.google.com/p/kryo/): Spark can also use +* [Kryo serialization](https://github.com/EsotericSoftware/kryo): Spark can also use the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly faster and more compact than Java serialization (often as much as 10x), but does not support all `Serializable` types and requires you to *register* the classes you'll use in the program in advance @@ -68,7 +68,7 @@ conf.set("spark.kryo.registrator", "mypackage.MyRegistrator") val sc = new SparkContext(conf) {% endhighlight %} -The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced +The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` From 46b69830737cc673bfe2f9b2b9f1ced6556b1af1 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 5 Aug 2014 13:08:23 -0700 Subject: [PATCH 038/231] SPARK-2380: Support displaying accumulator values in the web UI This patch adds support for giving accumulators user-visible names and displaying accumulator values in the web UI. This allows users to create custom counters that can display in the UI. The current approach displays both the accumulator deltas caused by each task and a "current" value of the accumulator totals for each stage, which gets update as tasks finish. Currently in Spark developers have been extending the `TaskMetrics` functionality to provide custom instrumentation for RDD's. This provides a potentially nicer alternative of going through the existing accumulator framework (actually `TaskMetrics` and accumulators are on an awkward collision course as we add more features to the former). The current patch demo's how we can use the feature to provide instrumentation for RDD input sizes. The nice thing about going through accumulators is that users can read the current value of the data being tracked in their programs. This could be useful to e.g. decide to short-circuit a Spark stage depending on how things are going. ![counters](https://cloud.githubusercontent.com/assets/320616/3488815/6ee7bc34-0505-11e4-84ce-e36d9886e2cf.png) Author: Patrick Wendell Closes #1309 from pwendell/metrics and squashes the following commits: 8815308 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into HEAD 93fbe0f [Patrick Wendell] Other minor fixes cc43f68 [Patrick Wendell] Updating unit tests c991b1b [Patrick Wendell] Moving some code into the Accumulators class 9a9ba3c [Patrick Wendell] More merge fixes c5ace9e [Patrick Wendell] More merge conflicts 1da15e3 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into metrics 9860c55 [Patrick Wendell] Potential solution to posting listener events 0bb0e33 [Patrick Wendell] Remove "display" variable and assume display = name.isDefined 0ec4ac7 [Patrick Wendell] Java API's e95bf69 [Patrick Wendell] Stash be97261 [Patrick Wendell] Style fix 8407308 [Patrick Wendell] Removing examples in Hadoop and RDD class 64d405f [Patrick Wendell] Adding missing file 5d8b156 [Patrick Wendell] Changes based on Kay's review. 9f18bad [Patrick Wendell] Minor style changes and tests 7a63abc [Patrick Wendell] Adding Json serialization and responding to Reynold's feedback ad85076 [Patrick Wendell] Example of using named accumulators for custom RDD metrics. 0b72660 [Patrick Wendell] Initial WIP example of supporing globally named accumulators. --- .../scala/org/apache/spark/Accumulators.scala | 19 ++++-- .../scala/org/apache/spark/SparkContext.scala | 19 ++++++ .../spark/api/java/JavaSparkContext.scala | 59 ++++++++++++++++++ .../spark/scheduler/AccumulableInfo.scala | 46 ++++++++++++++ .../apache/spark/scheduler/DAGScheduler.scala | 24 ++++++- .../apache/spark/scheduler/StageInfo.scala | 4 ++ .../org/apache/spark/scheduler/TaskInfo.scala | 9 +++ .../spark/ui/jobs/JobProgressListener.scala | 10 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 21 ++++++- .../org/apache/spark/ui/jobs/UIData.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 39 +++++++++++- .../apache/spark/util/JsonProtocolSuite.scala | 62 +++++++++++++++---- docs/programming-guide.md | 6 +- 13 files changed, 294 insertions(+), 27 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 9c55bfbb47626..12f2fe031cb1d 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -36,15 +36,21 @@ import org.apache.spark.serializer.JavaSerializer * * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` + * @param name human-readable name for use in Spark's web UI * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ class Accumulable[R, T] ( @transient initialValue: R, - param: AccumulableParam[R, T]) + param: AccumulableParam[R, T], + val name: Option[String]) extends Serializable { - val id = Accumulators.newId + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = + this(initialValue, param, None) + + val id: Long = Accumulators.newId + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false @@ -219,8 +225,10 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) - extends Accumulable[T,T](initialValue, param) +class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) + extends Accumulable[T,T](initialValue, param, name) { + def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) +} /** * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add @@ -281,4 +289,7 @@ private object Accumulators { } } } + + def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) + def stringifyValue(value: Any) = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ba21cfcde01a..e132955f0f850 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -760,6 +760,15 @@ class SparkContext(config: SparkConf) extends Logging { def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display + * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the + * driver can access the accumulator's `value`. + */ + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { + new Accumulator(initialValue, param, Some(name)) + } + /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. @@ -769,6 +778,16 @@ class SparkContext(config: SparkConf) extends Logging { def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the + * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can + * access the accumuable's `value`. + * @tparam T accumulator type + * @tparam R type that can be added to the accumulator + */ + def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) = + new Accumulable(initialValue, param, Some(name)) + /** * Create an accumulator from a "mutable collection" type. * diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index d9d1c5955ca99..e0a4815940db3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -429,6 +429,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def intAccumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue, name)(IntAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Integer]] + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -436,12 +446,31 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def doubleAccumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + sc.accumulator(initialValue, name)(DoubleAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Double]] + /** * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + intAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -449,6 +478,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator(initialValue: Double): Accumulator[java.lang.Double] = doubleAccumulator(initialValue) + + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `add` method. Only the master can access the accumulator's `value`. @@ -456,6 +495,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" + * values to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T]) + : Accumulator[T] = + sc.accumulator(initialValue, name)(accumulatorParam) + /** * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks * can "add" values with `add`. Only the master can access the accumuable's `value`. @@ -463,6 +512,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = sc.accumulable(initialValue)(param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks + * can "add" values with `add`. Only the master can access the accumuable's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R]) + : Accumulable[T, R] = + sc.accumulable(initialValue, name)(param) + /** * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala new file mode 100644 index 0000000000000..fa83372bb4d11 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + */ +@DeveloperApi +class AccumulableInfo ( + val id: Long, + val name: String, + val update: Option[String], // represents a partial update within a task + val value: String) { + + override def equals(other: Any): Boolean = other match { + case acc: AccumulableInfo => + this.id == acc.id && this.name == acc.name && + this.update == acc.update && this.value == acc.value + case _ => false + } +} + +object AccumulableInfo { + def apply(id: Long, name: String, update: Option[String], value: String) = + new AccumulableInfo(id, name, update, value) + + def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 9fa3a4e9c71ae..430e45ada5808 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -883,8 +883,14 @@ class DAGScheduler( val task = event.task val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + + // The success case is dealt with separately below, since we need to compute accumulator + // updates before posting. + if (event.reason != Success) { + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) + } + if (!stageIdToStage.contains(task.stageId)) { // Skip all the actions if the stage has been cancelled. return @@ -906,12 +912,26 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => + val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name.get + val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) + val stringValue = Accumulators.stringifyValue(acc.value) + stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue) + event.taskInfo.accumulables += + AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + } + } } catch { // If we see an exception during accumulator update, just log the error and move on. case e: Exception => logError(s"Failed to update accumulators for $task", e) } } + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 480891550eb60..2a407e47a05bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashMap + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo @@ -37,6 +39,8 @@ class StageInfo( var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None + /** Terminal values of accumulables updated during this stage. */ + val accumulables = HashMap[Long, AccumulableInfo]() def stageFailed(reason: String) { failureReason = Some(reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index ca0595f35143e..6fa1f2c880f7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.ListBuffer + import org.apache.spark.annotation.DeveloperApi /** @@ -41,6 +43,13 @@ class TaskInfo( */ var gettingResultTime: Long = 0 + /** + * Intermediate updates to accumulables during this task. Note that it is valid for the same + * accumulable to be updated multiple times in a single task or for two accumulables with the + * same name but different IDs to exist in a task. + */ + val accumulables = ListBuffer[AccumulableInfo]() + /** * The time when the task has completed successfully (including the time to remotely fetch * results, if necessary). diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index da2f5d3172fe2..a57a354620163 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, ListBuffer, Map} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -65,6 +65,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) + for ((id, info) <- stageCompleted.stageInfo.accumulables) { + stageData.accumulables(id) = info + } + poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) activeStages.remove(stageId) if (stage.failureReason.isEmpty) { @@ -130,6 +134,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) + for (accumulableInfo <- info.accumulables) { + stageData.accumulables(accumulableInfo.id) = accumulableInfo + } + val execSummaryMap = stageData.executorSummary val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cab26b9e2f7d3..8bc1ba758cf77 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,11 +20,12 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.scheduler.AccumulableInfo /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { @@ -51,6 +52,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) + val accumulables = listener.stageIdToData(stageId).accumulables val hasInput = stageData.inputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 @@ -95,10 +97,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { // scalastyle:on + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") + def accumulableRow(acc: AccumulableInfo) = {acc.name}{acc.value} + val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, + accumulables.values.toSeq) + val taskHeaders: Seq[String] = Seq( "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", - "Launch Time", "Duration", "GC Time") ++ + "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ @@ -208,11 +215,16 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } val executorTable = new ExecutorTable(stageId, parent) + + val maybeAccumulableTable: Seq[Node] = + if (accumulables.size > 0) {

    Accumulators

    ++ accumulableTable } else Seq() + val content = summary ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++

    Aggregated Metrics by Executor

    ++ executorTable.toNodeSeq ++ + maybeAccumulableTable ++

    Tasks

    ++ taskTable UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), @@ -279,6 +291,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} + + {Unparsed( + info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("
    ") + )} + - - Browser - Standalone Cluster Master - 8080 - Web UI - spark.master.ui.port - Jetty-based - - - Browser - Driver - 4040 - Web UI - spark.ui.port - Jetty-based - - - Browser - History Server - 18080 - Web UI - spark.history.ui.port - Jetty-based - - - Browser - Worker - 8081 - Web UI - spark.worker.ui.port - Jetty-based - - - - Application - Standalone Cluster Master - 7077 - Submit job to cluster - spark.driver.port - Akka-based. Set to "0" to choose a port randomly - - - Worker - Standalone Cluster Master - 7077 - Join cluster - spark.driver.port - Akka-based. Set to "0" to choose a port randomly - - - Application - Worker - (random) - Join cluster - SPARK_WORKER_PORT (standalone cluster) - Akka-based - - - - - Driver and other Workers - Worker - (random) - -
      -
    • File server for file and jars
    • -
    • Http Broadcast
    • -
    • Class file server (Spark Shell only)
    • -
    - - None - Jetty-based. Each of these services starts on a random port that cannot be configured - - - +Spark makes heavy use of the network, and some environments have strict requirements for using +tight firewall settings. For a complete list of ports to configure, see the +[security page](security.html#configuring-ports-for-network-security). # High Availability By default, standalone scheduling clusters are resilient to Worker failures (insofar as Spark itself is resilient to losing work by moving it to other workers). However, the scheduler uses a Master to make scheduling decisions, and this (by default) creates a single point of failure: if the Master crashes, no new applications can be created. In order to circumvent this, we have two high availability schemes, detailed below. -## Standby Masters with ZooKeeper +# Standby Masters with ZooKeeper **Overview** @@ -429,7 +347,7 @@ There's an important distinction to be made between "registering with a Master" Due to this property, new Masters can be created at any time, and the only thing you need to worry about is that _new_ applications and Workers can find it to register with in case it becomes the leader. Once registered, you're taken care of. -## Single-Node Recovery with Local File System +# Single-Node Recovery with Local File System **Overview** diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index aac621fe53938..40b588512ff08 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -330,6 +330,8 @@ object TestSettings { fork := true, javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", + javaOptions in Test += "-Dspark.ports.maxRetries=100", + javaOptions in Test += "-Dspark.ui.port=0", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index f60bbb4662af1..84b57cd2dc1af 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -102,7 +102,8 @@ import org.apache.spark.util.Utils val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + val classServerPort = conf.getInt("spark.replClassServer.port", 0) + val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything From 27a8d4ce39aa620a5926b33371fcf03bbcb18698 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Aug 2014 11:08:12 -0700 Subject: [PATCH 056/231] [SPARK-2875] [PySpark] [SQL] handle null in schemaRDD() Handle null in schemaRDD during converting them into Python. Author: Davies Liu Closes #1802 from davies/json and squashes the following commits: 88e6b1f [Davies Liu] handle null in schemaRDD() (cherry picked from commit 48789117c2dd6d38e0bd8d21cdbcb989913205a6) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 7 +++++ .../org/apache/spark/sql/SchemaRDD.scala | 27 +++++++++++-------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f1093701ddc89..adc56e7ec0e2b 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1231,6 +1231,13 @@ def jsonRDD(self, rdd, schema=None): ... "field3.field5[0] as f3 from table3") >>> srdd6.collect() [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] + + >>> sqlCtx.jsonRDD(sc.parallelize(['{}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] """ def func(iterator): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 57df79321b35d..33b2ed1b3a399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -382,21 +382,26 @@ class SchemaRDD( private[sql] def javaToPython: JavaRDD[Array[Byte]] = { import scala.collection.Map - def toJava(obj: Any, dataType: DataType): Any = dataType match { - case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct) - case array: ArrayType => obj match { - case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava - case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava - case arr if arr != null && arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) - case other => other - } - case mt: MapType => obj.asInstanceOf[Map[_, _]].map { + def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (obj: Row, struct: StructType) => rowToArray(obj, struct) + + case (seq: Seq[Any], array: ArrayType) => + seq.map(x => toJava(x, array.elementType)).asJava + case (list: JList[_], array: ArrayType) => + list.map(x => toJava(x, array.elementType)).asJava + case (arr, array: ArrayType) if arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + + case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type }.asJava + // Pyrolite can handle Timestamp - case other => obj + case (other, _) => other } + def rowToArray(row: Row, structType: StructType): Array[Any] = { val fields = structType.fields.map(field => field.dataType) row.zip(fields).map { From cf8e7fd5e18509531dc1ab04384d18a2f11330c2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 6 Aug 2014 12:28:35 -0700 Subject: [PATCH 057/231] [SPARK-2678][Core][SQL] A workaround for SPARK-2678 JIRA issues: - Main: [SPARK-2678](https://issues.apache.org/jira/browse/SPARK-2678) - Related: [SPARK-2874](https://issues.apache.org/jira/browse/SPARK-2874) Related PR: - #1715 This PR is both a fix for SPARK-2874 and a workaround for SPARK-2678. Fixing SPARK-2678 completely requires some API level changes that need further discussion, and we decided not to include it in Spark 1.1 release. As currently SPARK-2678 only affects Spark SQL scripts, this workaround is enough for Spark 1.1. Command line option handling logic in bash scripts looks somewhat dirty and duplicated, but it helps to provide a cleaner user interface as well as retain full downward compatibility for now. Author: Cheng Lian Closes #1801 from liancheng/spark-2874 and squashes the following commits: 8045d7a [Cheng Lian] Make sure test suites pass 8493a9e [Cheng Lian] Using eval to retain quoted arguments aed523f [Cheng Lian] Fixed typo in bin/spark-sql f12a0b1 [Cheng Lian] Worked arount SPARK-2678 daee105 [Cheng Lian] Fixed usage messages of all Spark SQL related scripts (cherry picked from commit a6cd31108f0d73ce6823daafe8447677e03cfd13) Signed-off-by: Patrick Wendell --- bin/beeline | 29 ++------ bin/spark-sql | 66 +++++++++++++++++-- .../spark/deploy/SparkSubmitArguments.scala | 39 ++++------- .../spark/deploy/SparkSubmitSuite.scala | 12 ++++ sbin/start-thriftserver.sh | 50 ++++++++++++-- .../hive/thriftserver/HiveThriftServer2.scala | 1 - .../sql/hive/thriftserver/CliSuite.scala | 19 +++--- .../thriftserver/HiveThriftServer2Suite.scala | 23 ++++--- 8 files changed, 164 insertions(+), 75 deletions(-) diff --git a/bin/beeline b/bin/beeline index 09fe366c609fa..1bda4dba50605 100755 --- a/bin/beeline +++ b/bin/beeline @@ -17,29 +17,14 @@ # limitations under the License. # -# Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +# +# Shell script for starting BeeLine -# Find the java binary -if [ -n "${JAVA_HOME}" ]; then - RUNNER="${JAVA_HOME}/bin/java" -else - if [ `command -v java` ]; then - RUNNER="java" - else - echo "JAVA_HOME is not set" >&2 - exit 1 - fi -fi +# Enter posix mode for bash +set -o posix -# Compute classpath using external script -classpath_output=$($FWDIR/bin/compute-classpath.sh) -if [[ "$?" != "0" ]]; then - echo "$classpath_output" - exit 1 -else - CLASSPATH=$classpath_output -fi +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" CLASS="org.apache.hive.beeline.BeeLine" -exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" +exec "$FWDIR/bin/spark-class" $CLASS "$@" diff --git a/bin/spark-sql b/bin/spark-sql index bba7f897b19bc..61ebd8ab6dec8 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -23,14 +23,72 @@ # Enter posix mode for bash set -o posix +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" + # Figure out where Spark is installed FWDIR="$(cd `dirname $0`/..; pwd)" -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./sbin/spark-sql [options]" +function usage { + echo "Usage: ./sbin/spark-sql [options] [cli option]" + pattern="usage" + pattern+="\|Spark assembly has been built with Hive" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|--help" + pattern+="\|=======" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "CLI options:" + $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 +} + +function ensure_arg_number { + arg_number=$1 + at_least=$2 + + if [[ $arg_number -lt $at_least ]]; then + usage + exit 1 + fi +} + +if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then + usage exit 0 fi -CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" -exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ +CLI_ARGS=() +SUBMISSION_ARGS=() + +while (($#)); do + case $1 in + -d | --define | --database | -f | -h | --hiveconf | --hivevar | -i | -p) + ensure_arg_number $# 2 + CLI_ARGS+=($1); shift + CLI_ARGS+=($1); shift + ;; + + -e) + ensure_arg_number $# 2 + CLI_ARGS+=($1); shift + CLI_ARGS+=(\"$1\"); shift + ;; + + -s | --silent) + CLI_ARGS+=($1); shift + ;; + + -v | --verbose) + # Both SparkSubmit and SparkSQLCLIDriver recognizes -v | --verbose + CLI_ARGS+=($1) + SUBMISSION_ARGS+=($1); shift + ;; + + *) + SUBMISSION_ARGS+=($1); shift + ;; + esac +done + +eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${CLI_ARGS[*]} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9391f24e71ed7..087dd4d633db0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -220,6 +220,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { var inSparkOpts = true + val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r // Delineates parsing of Spark options from parsing of user options. parse(opts) @@ -322,33 +323,21 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { verbose = true parse(tail) + case EQ_SEPARATED_OPT(opt, value) :: tail => + parse(opt :: value :: tail) + + case value :: tail if value.startsWith("-") => + SparkSubmit.printErrorAndExit(s"Unrecognized option '$value'.") + case value :: tail => - if (inSparkOpts) { - value match { - // convert --foo=bar to --foo bar - case v if v.startsWith("--") && v.contains("=") && v.split("=").size == 2 => - val parts = v.split("=") - parse(Seq(parts(0), parts(1)) ++ tail) - case v if v.startsWith("-") => - val errMessage = s"Unrecognized option '$value'." - SparkSubmit.printErrorAndExit(errMessage) - case v => - primaryResource = - if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { - Utils.resolveURI(v).toString - } else { - v - } - inSparkOpts = false - isPython = SparkSubmit.isPython(v) - parse(tail) + primaryResource = + if (!SparkSubmit.isShell(value) && !SparkSubmit.isInternal(value)) { + Utils.resolveURI(value).toString + } else { + value } - } else { - if (!value.isEmpty) { - childArgs += value - } - parse(tail) - } + isPython = SparkSubmit.isPython(value) + childArgs ++= tail case Nil => } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a5cdcfb5de03b..7e1ef80c84561 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,18 @@ class SparkSubmitSuite extends FunSuite with Matchers { appArgs.childArgs should be (Seq("some", "--weird", "args")) } + test("handles arguments to user program with name collision") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "userjar.jar", + "--master", "local", + "some", + "--weird", "args") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 8398e6f19b511..603f50ae13240 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -26,11 +26,53 @@ set -o posix # Figure out where Spark is installed FWDIR="$(cd `dirname $0`/..; pwd)" -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./sbin/start-thriftserver [options]" +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" + +function usage { + echo "Usage: ./sbin/start-thriftserver [options] [thrift server options]" + pattern="usage" + pattern+="\|Spark assembly has been built with Hive" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|=======" + pattern+="\|--help" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "Thrift server options:" + $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 +} + +function ensure_arg_number { + arg_number=$1 + at_least=$2 + + if [[ $arg_number -lt $at_least ]]; then + usage + exit 1 + fi +} + +if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then + usage exit 0 fi -CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" -exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ +THRIFT_SERVER_ARGS=() +SUBMISSION_ARGS=() + +while (($#)); do + case $1 in + --hiveconf) + ensure_arg_number $# 2 + THRIFT_SERVER_ARGS+=($1); shift + THRIFT_SERVER_ARGS+=($1); shift + ;; + + *) + SUBMISSION_ARGS+=($1); shift + ;; + esac +done + +eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${THRIFT_SERVER_ARGS[*]} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 08d3f983d9e71..6f7942aba314a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -40,7 +40,6 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - logWarning("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 69f19f826a802..2bf8cfdcacd22 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.{BufferedReader, InputStreamReader, PrintWriter} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { @@ -27,15 +28,15 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { val METASTORE_PATH = TestUtils.getMetastorePath("cli") override def beforeAll() { - val pb = new ProcessBuilder( - "../../bin/spark-sql", - "--master", - "local", - "--hiveconf", - s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", - "--hiveconf", - "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) - + val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" + val commands = + s"""../../bin/spark-sql + | --master local + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$WAREHOUSE_PATH + """.stripMargin.split("\\s+") + + val pb = new ProcessBuilder(commands: _*) process = pb.start() outputWriter = new PrintWriter(process.getOutputStream, true) inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index b7b7c9957ac34..78bffa2607349 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -25,6 +25,7 @@ import java.io.{BufferedReader, InputStreamReader} import java.net.ServerSocket import java.sql.{Connection, DriverManager, Statement} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging @@ -63,16 +64,18 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt // Forking a new process to start the Hive Thrift server. The reason to do this is it is // hard to clean up Hive resources entirely, so we just start a new process and kill // that process for cleanup. - val defaultArgs = Seq( - "../../sbin/start-thriftserver.sh", - "--master local", - "--hiveconf", - "hive.root.logger=INFO,console", - "--hiveconf", - s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", - "--hiveconf", - s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") - val pb = new ProcessBuilder(defaultArgs ++ args) + val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" + val command = + s"""../../sbin/start-thriftserver.sh + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$METASTORE_PATH + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$HOST + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$PORT + """.stripMargin.split("\\s+") + + val pb = new ProcessBuilder(command ++ args: _*) val environment = pb.environment() environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) From 4c19614e94d9c26109e5ffc6cf83665fab0bad84 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 6 Aug 2014 12:58:24 -0700 Subject: [PATCH 058/231] [SPARK-2627] [PySpark] have the build enforce PEP 8 automatically As described in [SPARK-2627](https://issues.apache.org/jira/browse/SPARK-2627), we'd like Python code to automatically be checked for PEP 8 compliance by Jenkins. This pull request aims to do that. Notes: * We may need to install [`pep8`](https://pypi.python.org/pypi/pep8) on the build server. * I'm expecting tests to fail now that PEP 8 compliance is being checked as part of the build. I'm fine with cleaning up any remaining PEP 8 violations as part of this pull request. * I did not understand why the RAT and scalastyle reports are saved to text files. I did the same for the PEP 8 check, but only so that the console output style can match those for the RAT and scalastyle checks. The PEP 8 report is removed right after the check is complete. * Updates to the ["Contributing to Spark"](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) guide will be submitted elsewhere, as I don't believe that text is part of the Spark repo. Author: Nicholas Chammas Author: nchammas Closes #1744 from nchammas/master and squashes the following commits: 274b238 [Nicholas Chammas] [SPARK-2627] [PySpark] minor indentation changes 983d963 [nchammas] Merge pull request #5 from apache/master 1db5314 [nchammas] Merge pull request #4 from apache/master 0e0245f [Nicholas Chammas] [SPARK-2627] undo erroneous whitespace fixes bf30942 [Nicholas Chammas] [SPARK-2627] PEP8: comment spacing 6db9a44 [nchammas] Merge pull request #3 from apache/master 7b4750e [Nicholas Chammas] merge upstream changes 91b7584 [Nicholas Chammas] [SPARK-2627] undo unnecessary line breaks 44e3e56 [Nicholas Chammas] [SPARK-2627] use tox.ini to exclude files b09fae2 [Nicholas Chammas] don't wrap comments unnecessarily bfb9f9f [Nicholas Chammas] [SPARK-2627] keep up with the PEP 8 fixes 9da347f [nchammas] Merge pull request #2 from apache/master aa5b4b5 [Nicholas Chammas] [SPARK-2627] follow Spark bash style for if blocks d0a83b9 [Nicholas Chammas] [SPARK-2627] check that pep8 downloaded fine dffb5dd [Nicholas Chammas] [SPARK-2627] download pep8 at runtime a1ce7ae [Nicholas Chammas] [SPARK-2627] space out test report sections 21da538 [Nicholas Chammas] [SPARK-2627] it's PEP 8, not PEP8 6f4900b [Nicholas Chammas] [SPARK-2627] more misc PEP 8 fixes fe57ed0 [Nicholas Chammas] removing merge conflict backups 9c01d4c [nchammas] Merge pull request #1 from apache/master 9a66cb0 [Nicholas Chammas] resolving merge conflicts a31ccc4 [Nicholas Chammas] [SPARK-2627] miscellaneous PEP 8 fixes beaa9ac [Nicholas Chammas] [SPARK-2627] fail check on non-zero status 723ed39 [Nicholas Chammas] always delete the report file 0541ebb [Nicholas Chammas] [SPARK-2627] call Python linter from run-tests 12440fa [Nicholas Chammas] [SPARK-2627] add Scala linter 61c07b9 [Nicholas Chammas] [SPARK-2627] add Python linter 75ad552 [Nicholas Chammas] make check output style consistent (cherry picked from commit d614967b0bad1e6c5277d612602ec0a653a00258) Signed-off-by: Reynold Xin --- dev/lint-python | 60 +++++++++++ dev/lint-scala | 23 ++++ dev/run-tests | 13 ++- dev/scalastyle | 2 +- python/pyspark/accumulators.py | 7 ++ python/pyspark/broadcast.py | 1 + python/pyspark/conf.py | 1 + python/pyspark/context.py | 25 +++-- python/pyspark/daemon.py | 5 +- python/pyspark/files.py | 1 + python/pyspark/java_gateway.py | 1 + python/pyspark/mllib/_common.py | 5 +- python/pyspark/mllib/classification.py | 8 ++ python/pyspark/mllib/clustering.py | 3 + python/pyspark/mllib/linalg.py | 2 + python/pyspark/mllib/random.py | 14 +-- python/pyspark/mllib/recommendation.py | 2 + python/pyspark/mllib/regression.py | 12 +++ python/pyspark/mllib/stat.py | 1 + python/pyspark/mllib/tests.py | 11 +- python/pyspark/mllib/tree.py | 4 +- python/pyspark/mllib/util.py | 1 + python/pyspark/rdd.py | 22 ++-- python/pyspark/rddsampler.py | 4 + python/pyspark/resultiterable.py | 2 + python/pyspark/serializers.py | 21 +++- python/pyspark/shuffle.py | 20 ++-- python/pyspark/sql.py | 66 ++++++++---- python/pyspark/storagelevel.py | 1 + python/pyspark/tests.py | 143 ++++++++++++++----------- python/test_support/userlibrary.py | 2 + tox.ini | 1 + 32 files changed, 348 insertions(+), 136 deletions(-) create mode 100755 dev/lint-python create mode 100755 dev/lint-scala diff --git a/dev/lint-python b/dev/lint-python new file mode 100755 index 0000000000000..4efddad839387 --- /dev/null +++ b/dev/lint-python @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" +PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" + +cd $SPARK_ROOT_DIR + +# Get pep8 at runtime so that we don't rely on it being installed on the build server. +#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 +#+ TODOs: +#+ - Dynamically determine latest release version of pep8 and use that. +#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?)) +PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" +PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py" + +curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" +curl_status=$? + +if [ $curl_status -ne 0 ]; then + echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + exit $curl_status +fi + + +# There is no need to write this output to a file +#+ first, but we do so so that the check status can +#+ be output before the report, like with the +#+ scalastyle and RAT checks. +python $PEP8_SCRIPT_PATH ./python > "$PEP8_REPORT_PATH" +pep8_status=${PIPESTATUS[0]} #$? + +if [ $pep8_status -ne 0 ]; then + echo "PEP 8 checks failed." + cat "$PEP8_REPORT_PATH" +else + echo "PEP 8 checks passed." +fi + +rm -f "$PEP8_REPORT_PATH" +rm "$PEP8_SCRIPT_PATH" + +exit $pep8_status diff --git a/dev/lint-scala b/dev/lint-scala new file mode 100755 index 0000000000000..c676dfdf4f44e --- /dev/null +++ b/dev/lint-scala @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +"$SCRIPT_DIR/scalastyle" diff --git a/dev/run-tests b/dev/run-tests index d401c90f41d7b..0e24515d1376c 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -66,16 +66,25 @@ fi set -e set -o pipefail +echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" dev/check-license +echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" -dev/scalastyle +dev/lint-scala +echo "" +echo "=========================================================================" +echo "Running Python style checks" +echo "=========================================================================" +dev/lint-python + +echo "" echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" @@ -89,11 +98,13 @@ fi echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" +echo "" echo "=========================================================================" echo "Running PySpark tests" echo "=========================================================================" ./python/run-tests +echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" diff --git a/dev/scalastyle b/dev/scalastyle index d9f2b91a3a091..b53053a04ff42 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -30,5 +30,5 @@ if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 else - echo -e "Scalastyle checks passed.\n" + echo -e "Scalastyle checks passed." fi diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 45d36e5d0e764..f133cf6f7befc 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -110,6 +110,7 @@ def _deserialize_accumulator(aid, zero_value, accum_param): class Accumulator(object): + """ A shared variable that can be accumulated, i.e., has a commutative and associative "add" operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=} @@ -166,6 +167,7 @@ def __repr__(self): class AccumulatorParam(object): + """ Helper object that defines how to accumulate values of a given type. """ @@ -186,6 +188,7 @@ def addInPlace(self, value1, value2): class AddingAccumulatorParam(AccumulatorParam): + """ An AccumulatorParam that uses the + operators to add values. Designed for simple types such as integers, floats, and lists. Requires the zero value for the underlying type @@ -210,6 +213,7 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ This handler will keep polling updates from the same socket until the server is shutdown. @@ -228,7 +232,9 @@ def handle(self): # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) + class AccumulatorServer(SocketServer.TCPServer): + """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. @@ -239,6 +245,7 @@ def shutdown(self): self.server_shutdown = True SocketServer.TCPServer.shutdown(self) + def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 43f40f8783bfd..f3e64989ed564 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -45,6 +45,7 @@ def _from_id(bid): class Broadcast(object): + """ A broadcast variable created with L{SparkContext.broadcast()}. diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b4c82f519bd53..fb716f6753a45 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -56,6 +56,7 @@ class SparkConf(object): + """ Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2e80eb50f2207..4001ecab5ea00 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -47,6 +47,7 @@ class SparkContext(object): + """ Main entry point for Spark functionality. A SparkContext represents the connection to a Spark cluster, and can be used to create L{RDD}s and @@ -213,7 +214,7 @@ def _ensure_initialized(cls, instance=None, gateway=None): if instance: if (SparkContext._active_spark_context and - SparkContext._active_spark_context != instance): + SparkContext._active_spark_context != instance): currentMaster = SparkContext._active_spark_context.master currentAppName = SparkContext._active_spark_context.appName callsite = SparkContext._active_spark_context._callsite @@ -406,7 +407,7 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, - keyConverter, valueConverter, minSplits, batchSize) + keyConverter, valueConverter, minSplits, batchSize) return RDD(jrdd, self, ser) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -437,7 +438,8 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -465,7 +467,8 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -496,7 +499,8 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -523,8 +527,9 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, jconf = self._dictToJavaMap(conf) batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() - jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, - keyConverter, valueConverter, jconf, batchSize) + jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def _checkpointFile(self, name, input_deserializer): @@ -555,8 +560,7 @@ def union(self, rdds): first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] rest = ListConverter().convert(rest, self._gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self, - rdds[0]._jrdd_deserializer) + return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -568,8 +572,7 @@ def broadcast(self, value): pickleSer = PickleSerializer() pickled = pickleSer.dumps(value) jbroadcast = self._jsc.broadcast(bytearray(pickled)) - return Broadcast(jbroadcast.id(), value, jbroadcast, - self._pickled_broadcast_vars) + return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index b00da833d06f1..e73538baf0b93 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -43,7 +43,7 @@ def worker(sock): """ # Redirect stdout to stderr os.dup2(2, 1) - sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 + sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) @@ -134,8 +134,7 @@ def handle_sigchld(*args): try: os.kill(worker_pid, signal.SIGKILL) except OSError: - pass # process already died - + pass # process already died if listen_sock in ready_fds: sock, addr = listen_sock.accept() diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 57ee14eeb7776..331de9a9b2212 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -19,6 +19,7 @@ class SparkFiles(object): + """ Resolves paths to files added through L{SparkContext.addFile()}. diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2c129679f47f3..37386ab0d7d49 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -65,6 +65,7 @@ def preexec_func(): # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: class EchoOutputThread(Thread): + def __init__(self, stream): Thread.__init__(self) self.daemon = True diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 9c1565affbdac..db341da85f865 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -72,9 +72,9 @@ # Python interpreter must agree on what endian the machine is. -DENSE_VECTOR_MAGIC = 1 +DENSE_VECTOR_MAGIC = 1 SPARSE_VECTOR_MAGIC = 2 -DENSE_MATRIX_MAGIC = 3 +DENSE_MATRIX_MAGIC = 3 LABELED_POINT_MAGIC = 4 @@ -443,6 +443,7 @@ def _serialize_rating(r): class RatingDeserializer(Serializer): + def loads(self, stream): length = struct.unpack("!i", stream.read(4))[0] ba = stream.read(length) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 5ec1a8084d269..ffdda7ee19302 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -31,6 +31,7 @@ class LogisticRegressionModel(LinearModel): + """A linear binary classification model derived from logistic regression. >>> data = [ @@ -60,6 +61,7 @@ class LogisticRegressionModel(LinearModel): >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0 True """ + def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept @@ -72,6 +74,7 @@ def predict(self, x): class LogisticRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=1.0, regType=None, intercept=False): @@ -108,6 +111,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, class SVMModel(LinearModel): + """A support vector machine. >>> data = [ @@ -131,6 +135,7 @@ class SVMModel(LinearModel): >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0 True """ + def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept @@ -138,6 +143,7 @@ def predict(self, x): class SVMWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): @@ -173,6 +179,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, class NaiveBayesModel(object): + """ Model for Naive Bayes classifiers. @@ -213,6 +220,7 @@ def predict(self, x): class NaiveBayes(object): + @classmethod def train(cls, data, lambda_=1.0): """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b380e8f6c8725..a0630d1d5c58b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -27,6 +27,7 @@ class KMeansModel(object): + """A clustering model derived from the k-means method. >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) @@ -55,6 +56,7 @@ class KMeansModel(object): >>> type(model.clusterCenters) """ + def __init__(self, centers): self.centers = centers @@ -76,6 +78,7 @@ def predict(self, x): class KMeans(object): + @classmethod def train(cls, data, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 54720c2324ca6..9a239abfbbeb1 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -27,6 +27,7 @@ class SparseVector(object): + """ A simple sparse vector class for passing data to MLlib. Users may alternatively pass SciPy's {scipy.sparse} data types. @@ -192,6 +193,7 @@ def __ne__(self, other): class Vectors(object): + """ Factory methods for working with vectors. Note that dense vectors are simply represented as NumPy array objects, so there is no need diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 36e710dbae7a8..eb496688b6eef 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -24,7 +24,9 @@ from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector from pyspark.serializers import NoOpSerializer + class RandomRDDGenerators: + """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. @@ -53,7 +55,7 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) + uniform = RDD(jrdd, sc, NoOpSerializer()) return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod @@ -77,7 +79,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) + normal = RDD(jrdd, sc, NoOpSerializer()) return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod @@ -98,7 +100,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) + poisson = RDD(jrdd, sc, NoOpSerializer()) return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod @@ -118,7 +120,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) + uniform = RDD(jrdd, sc, NoOpSerializer()) return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod @@ -138,7 +140,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) + normal = RDD(jrdd, sc, NoOpSerializer()) return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod @@ -161,7 +163,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) + poisson = RDD(jrdd, sc, NoOpSerializer()) return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 6c385042ffa5f..e863fc249ec36 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -26,6 +26,7 @@ class MatrixFactorizationModel(object): + """A matrix factorisation model trained by regularized alternating least-squares. @@ -58,6 +59,7 @@ def predictAll(self, usersProducts): class ALS(object): + @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): sc = ratings.context diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 041b119269427..d8792cf44872f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -27,6 +27,7 @@ class LabeledPoint(object): + """ The features and labels of a data point. @@ -34,6 +35,7 @@ class LabeledPoint(object): @param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ + def __init__(self, label, features): self.label = label if (type(features) == ndarray or type(features) == SparseVector @@ -49,7 +51,9 @@ def __str__(self): class LinearModel(object): + """A linear model that has a vector of coefficients and an intercept.""" + def __init__(self, weights, intercept): self._coeff = weights self._intercept = intercept @@ -64,6 +68,7 @@ def intercept(self): class LinearRegressionModelBase(LinearModel): + """A linear regression model. >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) @@ -72,6 +77,7 @@ class LinearRegressionModelBase(LinearModel): >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 True """ + def predict(self, x): """Predict the value of the dependent variable given a vector x""" """containing values for the independent variables.""" @@ -80,6 +86,7 @@ def predict(self, x): class LinearRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit. >>> from pyspark.mllib.regression import LabeledPoint @@ -111,6 +118,7 @@ class LinearRegressionModel(LinearRegressionModelBase): class LinearRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=1.0, regType=None, intercept=False): @@ -146,6 +154,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, class LassoModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an l_1 penalty term. @@ -178,6 +187,7 @@ class LassoModel(LinearRegressionModelBase): class LassoWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): @@ -189,6 +199,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, class RidgeRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an l_2 penalty term. @@ -221,6 +232,7 @@ class RidgeRegressionModel(LinearRegressionModelBase): class RidgeRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 0a08a562d1f1f..982906b9d09f0 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -24,6 +24,7 @@ _serialize_double, _serialize_double_vector, \ _deserialize_double, _deserialize_double_matrix + class Statistics(object): @staticmethod diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 9d1e5be637a9a..6f3ec8ac94bac 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -39,6 +39,7 @@ class VectorTests(unittest.TestCase): + def test_serialize(self): sv = SparseVector(4, {1: 1, 3: 2}) dv = array([1., 2., 3., 4.]) @@ -81,6 +82,7 @@ def test_squared_distance(self): class ListTests(PySparkTestCase): + """ Test MLlib algorithms on plain lists, to make sure they're passed through as NumPy arrays. @@ -128,7 +130,7 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = \ DecisionTree.trainClassifier(rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) @@ -168,7 +170,7 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = \ DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) @@ -179,6 +181,7 @@ def test_regression(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): + """ Test both vector operations and MLlib algorithms with SciPy sparse matrices, if SciPy is available. @@ -276,7 +279,7 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) @@ -315,7 +318,7 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 1e0006df75ac6..2518001ea0b93 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -25,7 +25,9 @@ from pyspark.mllib.regression import LabeledPoint from pyspark.serializers import NoOpSerializer + class DecisionTreeModel(object): + """ A decision tree model for classification or regression. @@ -77,6 +79,7 @@ def __str__(self): class DecisionTree(object): + """ Learning algorithm for a decision tree model for classification or regression. @@ -174,7 +177,6 @@ def trainRegressor(data, categoricalFeaturesInfo={}, categoricalFeaturesInfo, impurity, maxDepth, maxBins) - @staticmethod def train(data, algo, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins=100): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 639cda6350229..4962d05491c03 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -26,6 +26,7 @@ class MLUtils: + """ Helper methods to load, save and pre-process data used in MLlib. """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 309f5a9b6038d..30b834d2085cd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -233,7 +233,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): def _toPickleSerialization(self): if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): + self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): return self else: return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) @@ -1079,7 +1079,9 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, - outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) + outputFormatClass, + keyClass, valueClass, + keyConverter, valueConverter, jconf) def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1125,8 +1127,10 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, - outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, - jconf, compressionCodecClass) + outputFormatClass, + keyClass, valueClass, + keyConverter, valueConverter, + jconf, compressionCodecClass) def saveAsSequenceFile(self, path, compressionCodecClass=None): """ @@ -1348,7 +1352,7 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): outputSerializer = self.ctx._unbatched_serializer limit = (_parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) / 2) + "spark.python.worker.memory", "512m")) / 2) def add_shuffle_key(split, iterator): @@ -1430,12 +1434,12 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + "spark.python.worker.memory", "512m")) agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + if spill else InMemoryMerger(agg) merger.mergeValues(iterator) return merger.iteritems() @@ -1444,7 +1448,7 @@ def combineLocally(iterator): def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + if spill else InMemoryMerger(agg) merger.mergeCombiners(iterator) return merger.iteritems() @@ -1588,7 +1592,7 @@ def sampleByKey(self, withReplacement, fractions, seed=None): """ for fraction in fractions.values(): assert fraction >= 0.0, "Negative fraction value: %s" % fraction - return self.mapPartitionsWithIndex( \ + return self.mapPartitionsWithIndex( RDDStratifiedSampler(withReplacement, fractions, seed).func, True) def subtractByKey(self, other, numPartitions=None): diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 2df000fdb08ca..55e247da0e4dc 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -20,6 +20,7 @@ class RDDSamplerBase(object): + def __init__(self, withReplacement, seed=None): try: import numpy @@ -95,6 +96,7 @@ def shuffle(self, vals): class RDDSampler(RDDSamplerBase): + def __init__(self, withReplacement, fraction, seed=None): RDDSamplerBase.__init__(self, withReplacement, seed) self._fraction = fraction @@ -113,7 +115,9 @@ def func(self, split, iterator): if self.getUniformSample(split) <= self._fraction: yield obj + class RDDStratifiedSampler(RDDSamplerBase): + def __init__(self, withReplacement, fractions, seed=None): RDDSamplerBase.__init__(self, withReplacement, seed) self._fractions = fractions diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index df34740fc8176..ef04c82866e6c 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -21,9 +21,11 @@ class ResultIterable(collections.Iterable): + """ A special result iterable. This is used because the standard iterator can not be pickled """ + def __init__(self, data): self.data = data self.index = 0 diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a10f85b55ad30..b35558db3e007 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -111,6 +111,7 @@ def __ne__(self, other): class FramedSerializer(Serializer): + """ Serializer that writes objects as a stream of (length, data) pairs, where C{length} is a 32-bit integer and data is C{length} bytes. @@ -162,6 +163,7 @@ def loads(self, obj): class BatchedSerializer(Serializer): + """ Serializes a stream of objects in batches by calling its wrapped Serializer with streams of objects. @@ -207,6 +209,7 @@ def __str__(self): class CartesianDeserializer(FramedSerializer): + """ Deserializes the JavaRDD cartesian() of two PythonRDDs. """ @@ -240,6 +243,7 @@ def __str__(self): class PairDeserializer(CartesianDeserializer): + """ Deserializes the JavaRDD zip() of two PythonRDDs. """ @@ -289,6 +293,7 @@ def _hack_namedtuple(cls): """ Make class generated by namedtuple picklable """ name = cls.__name__ fields = cls._fields + def __reduce__(self): return (_restore, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ @@ -301,10 +306,11 @@ def _hijack_namedtuple(): if hasattr(collections.namedtuple, "__hijack"): return - global _old_namedtuple # or it will put in closure + global _old_namedtuple # or it will put in closure + def _copy_func(f): return types.FunctionType(f.func_code, f.func_globals, f.func_name, - f.func_defaults, f.func_closure) + f.func_defaults, f.func_closure) _old_namedtuple = _copy_func(collections.namedtuple) @@ -323,15 +329,16 @@ def namedtuple(name, fields, verbose=False, rename=False): # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.iteritems(): if (type(o) is type and o.__base__ is tuple - and hasattr(o, "_fields") - and "__reduce__" not in o.__dict__): - _hack_namedtuple(o) # hack inplace + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__): + _hack_namedtuple(o) # hack inplace _hijack_namedtuple() class PickleSerializer(FramedSerializer): + """ Serializes objects using Python's cPickle serializer: @@ -354,6 +361,7 @@ def dumps(self, obj): class MarshalSerializer(FramedSerializer): + """ Serializes objects using Python's Marshal serializer: @@ -367,9 +375,11 @@ class MarshalSerializer(FramedSerializer): class AutoSerializer(FramedSerializer): + """ Choose marshal or cPickle as serialization protocol autumatically """ + def __init__(self): FramedSerializer.__init__(self) self._type = None @@ -394,6 +404,7 @@ def loads(self, obj): class UTF8Deserializer(Serializer): + """ Deserializes streams written by String.getBytes. """ diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e3923d1c36c57..2c68cd4921deb 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -45,7 +45,7 @@ def get_used_memory(): return int(line.split()[1]) >> 10 else: warnings.warn("Please install psutil to have better " - "support with spilling") + "support with spilling") if platform.system() == "Darwin": import resource rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss @@ -141,7 +141,7 @@ class ExternalMerger(Merger): This class works as follows: - - It repeatedly combine the items and save them in one dict in + - It repeatedly combine the items and save them in one dict in memory. - When the used memory goes above memory limit, it will split @@ -190,12 +190,12 @@ class ExternalMerger(Merger): MAX_TOTAL_PARTITIONS = 4096 def __init__(self, aggregator, memory_limit=512, serializer=None, - localdirs=None, scale=1, partitions=59, batch=1000): + localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + BatchedSerializer(PickleSerializer(), 1024) self.localdirs = localdirs or self._get_dirs() # number of partitions when spill data into disks self.partitions = partitions @@ -341,7 +341,7 @@ def _spill(self): self.pdata[i].clear() self.spills += 1 - gc.collect() # release the memory as much as possible + gc.collect() # release the memory as much as possible def iteritems(self): """ Return all merged items as iterator """ @@ -370,8 +370,8 @@ def _external_items(self): if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS and j < self.spills - 1 and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible for v in self._recursive_merged_items(i): yield v return @@ -409,9 +409,9 @@ def _recursive_merged_items(self, start): for i in range(start, self.partitions): subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] + for d in self.localdirs] m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions) + subdirs, self.scale * self.partitions) m.pdata = [{} for _ in range(self.partitions)] limit = self._next_limit() @@ -419,7 +419,7 @@ def _recursive_merged_items(self, start): path = self._get_spill_dir(j) p = os.path.join(path, str(i)) m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) + self.serializer.load_stream(open(p))) if get_used_memory() > limit: m._spill() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index adc56e7ec0e2b..950e275adbf01 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -45,6 +45,7 @@ class DataType(object): + """Spark SQL DataType""" def __repr__(self): @@ -62,6 +63,7 @@ def __ne__(self, other): class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" _instances = {} @@ -73,6 +75,7 @@ def __call__(cls): class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" __metaclass__ = PrimitiveTypeSingleton @@ -83,6 +86,7 @@ def __eq__(self, other): class StringType(PrimitiveType): + """Spark SQL StringType The data type representing string values. @@ -90,6 +94,7 @@ class StringType(PrimitiveType): class BinaryType(PrimitiveType): + """Spark SQL BinaryType The data type representing bytearray values. @@ -97,6 +102,7 @@ class BinaryType(PrimitiveType): class BooleanType(PrimitiveType): + """Spark SQL BooleanType The data type representing bool values. @@ -104,6 +110,7 @@ class BooleanType(PrimitiveType): class TimestampType(PrimitiveType): + """Spark SQL TimestampType The data type representing datetime.datetime values. @@ -111,6 +118,7 @@ class TimestampType(PrimitiveType): class DecimalType(PrimitiveType): + """Spark SQL DecimalType The data type representing decimal.Decimal values. @@ -118,6 +126,7 @@ class DecimalType(PrimitiveType): class DoubleType(PrimitiveType): + """Spark SQL DoubleType The data type representing float values. @@ -125,6 +134,7 @@ class DoubleType(PrimitiveType): class FloatType(PrimitiveType): + """Spark SQL FloatType The data type representing single precision floating-point values. @@ -132,6 +142,7 @@ class FloatType(PrimitiveType): class ByteType(PrimitiveType): + """Spark SQL ByteType The data type representing int values with 1 singed byte. @@ -139,6 +150,7 @@ class ByteType(PrimitiveType): class IntegerType(PrimitiveType): + """Spark SQL IntegerType The data type representing int values. @@ -146,6 +158,7 @@ class IntegerType(PrimitiveType): class LongType(PrimitiveType): + """Spark SQL LongType The data type representing long values. If the any value is @@ -155,6 +168,7 @@ class LongType(PrimitiveType): class ShortType(PrimitiveType): + """Spark SQL ShortType The data type representing int values with 2 signed bytes. @@ -162,6 +176,7 @@ class ShortType(PrimitiveType): class ArrayType(DataType): + """Spark SQL ArrayType The data type representing list values. An ArrayType object @@ -187,10 +202,11 @@ def __init__(self, elementType, containsNull=False): def __str__(self): return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) + str(self.containsNull).lower()) class MapType(DataType): + """Spark SQL MapType The data type representing dict values. A MapType object comprises @@ -226,10 +242,11 @@ def __init__(self, keyType, valueType, valueContainsNull=True): def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) + str(self.valueContainsNull).lower()) class StructField(DataType): + """Spark SQL StructField Represents a field in a StructType. @@ -263,10 +280,11 @@ def __init__(self, name, dataType, nullable): def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) + str(self.nullable).lower()) class StructType(DataType): + """Spark SQL StructType The data type representing rows. @@ -291,7 +309,7 @@ def __init__(self, fields): def __repr__(self): return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) + ",".join(str(field) for field in self.fields)) def _parse_datatype_list(datatype_list_string): @@ -319,7 +337,7 @@ def _parse_datatype_list(datatype_list_string): _all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) + if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) def _parse_datatype_string(datatype_string): @@ -459,16 +477,16 @@ def _infer_schema(row): items = sorted(row.items()) elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple + if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row + elif hasattr(row, "__FIELDS__"): # Row items = zip(row.__FIELDS__, tuple(row)) elif all(isinstance(x, tuple) and len(x) == 2 for x in row): items = row else: raise ValueError("Can't infer schema from tuple") - elif hasattr(row, "__dict__"): # object + elif hasattr(row, "__dict__"): # object items = sorted(row.__dict__.items()) else: @@ -499,7 +517,7 @@ def _create_converter(obj, dataType): conv = lambda o: tuple(o.get(n) for n in names) elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple + if hasattr(obj, "_fields"): # namedtuple conv = tuple elif hasattr(obj, "__FIELDS__"): conv = tuple @@ -508,7 +526,7 @@ def _create_converter(obj, dataType): else: raise ValueError("unexpected tuple") - elif hasattr(obj, "__dict__"): # object + elif hasattr(obj, "__dict__"): # object conv = lambda o: [o.__dict__.get(n, None) for n in names] nested = any(_has_struct(f.dataType) for f in dataType.fields) @@ -660,7 +678,7 @@ def _infer_schema_type(obj, dataType): assert len(fs) == len(obj), \ "Obj(%s) have different length with fields(%s)" % (obj, fs) fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] + for o, f in zip(obj, fs)] return StructType(fields) else: @@ -683,6 +701,7 @@ def _infer_schema_type(obj, dataType): StructType: (tuple, list), } + def _verify_type(obj, dataType): """ Verify the type of obj against dataType, raise an exception if @@ -728,7 +747,7 @@ def _verify_type(obj, dataType): elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) + "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) @@ -861,6 +880,7 @@ def __reduce__(self): raise Exception("unexpected data type: %s" % dataType) class Row(tuple): + """ Row in SchemaRDD """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) @@ -872,7 +892,7 @@ class Row(tuple): def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) + for n in self.__FIELDS__)) def __reduce__(self): return (_restore_object, (self.__DATATYPE__, tuple(self))) @@ -881,6 +901,7 @@ def __reduce__(self): class SQLContext: + """Main entry point for SparkSQL functionality. A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as @@ -960,7 +981,7 @@ def registerFunction(self, name, f, returnType=StringType()): env = MapConverter().convert(self._sc.environment, self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) + self._sc._gateway._gateway_client) self._ssql_ctx.registerPython(name, bytearray(CloudPickleSerializer().dumps(command)), env, @@ -1012,7 +1033,7 @@ def inferSchema(self, rdd): first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " - "can not infer schema") + "can not infer schema") if type(first) is dict: warnings.warn("Using RDD of dict to inferSchema is deprecated") @@ -1287,6 +1308,7 @@ def uncacheTable(self, tableName): class HiveContext(SQLContext): + """A variant of Spark SQL that integrates with data stored in Hive. Configuration for Hive is read from hive-site.xml on the classpath. @@ -1327,6 +1349,7 @@ def hql(self, hqlQuery): class LocalHiveContext(HiveContext): + """Starts up an instance of hive where metadata is stored locally. An in-process metadata data is created with data stored in ./metadata. @@ -1357,7 +1380,7 @@ class LocalHiveContext(HiveContext): def __init__(self, sparkContext, sqlContext=None): HiveContext.__init__(self, sparkContext, sqlContext) warnings.warn("LocalHiveContext is deprecated. " - "Use HiveContext instead.", DeprecationWarning) + "Use HiveContext instead.", DeprecationWarning) def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) @@ -1376,6 +1399,7 @@ def _create_row(fields, values): class Row(tuple): + """ A row in L{SchemaRDD}. The fields in it can be accessed like attributes. @@ -1417,7 +1441,6 @@ def __new__(self, *args, **kwargs): else: raise ValueError("No args or kwargs") - # let obect acs like class def __call__(self, *args): """create new Row object""" @@ -1443,12 +1466,13 @@ def __reduce__(self): def __repr__(self): if hasattr(self, "__FIELDS__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) + for k, v in zip(self.__FIELDS__, self)) else: return "" % ", ".join(self) class SchemaRDD(RDD): + """An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can @@ -1659,7 +1683,7 @@ def subtract(self, other, numPartitions=None): rdd = self._jschema_rdd.subtract(other._jschema_rdd) else: rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) + numPartitions) return SchemaRDD(rdd, self.sql_ctx) else: raise ValueError("Can only subtract another SchemaRDD") @@ -1686,9 +1710,9 @@ def _test(): jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', + '"field6":[{"field7": "row2"}]}', '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' + '"field3":{"field4":33, "field5": []}}' ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 5d77a131f2856..2aa0fb9d2c1ed 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -19,6 +19,7 @@ class StorageLevel: + """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4ac94ba729d35..88a61176e51ab 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -62,53 +62,53 @@ def setUp(self): self.N = 1 << 16 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) - self.agg = Aggregator(lambda x: [x], - lambda x, y: x.append(y) or x, - lambda x, y: x.extend(y) or x) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) def test_in_memory(self): m = InMemoryMerger(self.agg) m.mergeValues(self.data) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = InMemoryMerger(self.agg) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = ExternalMerger(self.agg, 1000) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) def test_medium_dataset(self): m = ExternalMerger(self.agg, 10) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = ExternalMerger(self.agg, 10) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N)) * 3) + sum(xrange(self.N)) * 3) def test_huge_dataset(self): m = ExternalMerger(self.agg, 10) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), - self.N * 10) + self.N * 10) m._cleanup() @@ -188,6 +188,7 @@ def test_add_py_file(self): log4j = self.sc._jvm.org.apache.log4j old_level = log4j.LogManager.getRootLogger().getLevel() log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + def func(x): from userlibrary import UserClass return UserClass().hello() @@ -355,8 +356,8 @@ def test_sequencefiles(self): self.assertEqual(doubles, ed) bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) ebs = [(1, bytearray('aa', 'utf-8')), (1, bytearray('aa', 'utf-8')), (2, bytearray('aa', 'utf-8')), @@ -428,9 +429,9 @@ def test_sequencefiles(self): self.assertEqual(clazz[0], ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + batchSize=1).collect()) self.assertEqual(unbatched_clazz[0], ec) def test_oldhadoop(self): @@ -443,7 +444,7 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - oldconf = {"mapred.input.dir" : hellopath} + oldconf = {"mapred.input.dir": hellopath} hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -462,7 +463,7 @@ def test_newhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - newconf = {"mapred.input.dir" : hellopath} + newconf = {"mapred.input.dir": hellopath} hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -517,6 +518,7 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + class TestOutputFormat(PySparkTestCase): def setUp(self): @@ -574,8 +576,8 @@ def test_sequencefiles(self): def test_oldhadoop(self): basepath = self.tempdir.name dict_data = [(1, {}), - (1, {"row1" : 1.0}), - (2, {"row2" : 2.0})] + (1, {"row1": 1.0}), + (2, {"row2": 2.0})] self.sc.parallelize(dict_data).saveAsHadoopFile( basepath + "/oldhadoop/", "org.apache.hadoop.mapred.SequenceFileOutputFormat", @@ -589,12 +591,13 @@ def test_oldhadoop(self): self.assertEqual(result, dict_data) conf = { - "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.MapWritable", - "mapred.output.dir" : basepath + "/olddataset/"} + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapred.output.dir": basepath + "/olddataset/" + } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapred.input.dir" : basepath + "/olddataset/"} + input_conf = {"mapred.input.dir": basepath + "/olddataset/"} old_dataset = sorted(self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -622,14 +625,17 @@ def test_newhadoop(self): valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) self.assertEqual(result, array_data) - conf = {"mapreduce.outputformat.class" : - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.spark.api.python.DoubleArrayWritable", - "mapred.output.dir" : basepath + "/newdataset/"} - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(conf, + conf = { + "mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapred.output.dir": basepath + "/newdataset/" + } + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( + conf, valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapred.input.dir" : basepath + "/newdataset/"} + input_conf = {"mapred.input.dir": basepath + "/newdataset/"} new_dataset = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -640,7 +646,7 @@ def test_newhadoop(self): def test_newolderror(self): basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( basepath + "/newolderror/saveAsHadoopFile/", "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) @@ -650,7 +656,7 @@ def test_newolderror(self): def test_bad_inputs(self): basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( basepath + "/badinputs/saveAsHadoopFile/", "org.apache.hadoop.mapred.NotValidOutputFormat")) @@ -685,30 +691,32 @@ def test_reserialization(self): result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) self.assertEqual(result1, data) - rdd.saveAsHadoopFile(basepath + "/reserialize/hadoop", - "org.apache.hadoop.mapred.SequenceFileOutputFormat") + rdd.saveAsHadoopFile( + basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) self.assertEqual(result2, data) - rdd.saveAsNewAPIHadoopFile(basepath + "/reserialize/newhadoop", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + rdd.saveAsNewAPIHadoopFile( + basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) self.assertEqual(result3, data) conf4 = { - "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.dir" : basepath + "/reserialize/dataset"} + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.dir": basepath + "/reserialize/dataset"} rdd.saveAsHadoopDataset(conf4) result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) self.assertEqual(result4, data) - conf5 = {"mapreduce.outputformat.class" : - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.dir" : basepath + "/reserialize/newdataset"} + conf5 = {"mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.dir": basepath + "/reserialize/newdataset"} rdd.saveAsNewAPIHadoopDataset(conf5) result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) @@ -719,25 +727,28 @@ def test_unbatched_save_and_read(self): self.sc.parallelize(ei, numSlices=len(ei)).saveAsSequenceFile( basepath + "/unbatched/") - unbatched_sequence = sorted(self.sc.sequenceFile(basepath + "/unbatched/", + unbatched_sequence = sorted(self.sc.sequenceFile( + basepath + "/unbatched/", batchSize=1).collect()) self.assertEqual(unbatched_sequence, ei) - unbatched_hadoopFile = sorted(self.sc.hadoopFile(basepath + "/unbatched/", + unbatched_hadoopFile = sorted(self.sc.hadoopFile( + basepath + "/unbatched/", "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text", batchSize=1).collect()) self.assertEqual(unbatched_hadoopFile, ei) - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile(basepath + "/unbatched/", + unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( + basepath + "/unbatched/", "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text", batchSize=1).collect()) self.assertEqual(unbatched_newAPIHadoopFile, ei) - oldconf = {"mapred.input.dir" : basepath + "/unbatched/"} + oldconf = {"mapred.input.dir": basepath + "/unbatched/"} unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -746,7 +757,7 @@ def test_unbatched_save_and_read(self): batchSize=1).collect()) self.assertEqual(unbatched_hadoopRDD, ei) - newconf = {"mapred.input.dir" : basepath + "/unbatched/"} + newconf = {"mapred.input.dir": basepath + "/unbatched/"} unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -763,7 +774,9 @@ def test_malformed_RDD(self): self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( basepath + "/malformed/sequence")) + class TestDaemon(unittest.TestCase): + def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM sock = socket(AF_INET, SOCK_STREAM) @@ -810,12 +823,15 @@ def test_termination_sigterm(self): class TestWorker(PySparkTestCase): + def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() path = temp.name + def sleep(x): - import os, time + import os + import time with open(path, 'w') as f: f.write("%d %d" % (os.getppid(), os.getpid())) time.sleep(100) @@ -845,7 +861,7 @@ def run(): os.kill(worker_pid, 0) time.sleep(0.1) except OSError: - break # worker was killed + break # worker was killed else: self.fail("worker has not been killed after 5 seconds") @@ -855,12 +871,13 @@ def run(): self.fail("daemon had been killed") def test_fd_leak(self): - N = 1100 # fd limit is 1024 by default + N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) class TestSparkSubmit(unittest.TestCase): + def setUp(self): self.programDir = tempfile.mkdtemp() self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") @@ -953,9 +970,9 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen( - [self.sparkSubmit, "--py-files", zip, "--master", "local-cluster[1,1,512]", script], - stdout=subprocess.PIPE) + proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + "local-cluster[1,1,512]", script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out) @@ -981,6 +998,7 @@ def test_single_script_on_cluster(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): + """General PySpark tests that depend on scipy """ def test_serialize(self): @@ -993,15 +1011,16 @@ def test_serialize(self): @unittest.skipIf(not _have_numpy, "NumPy not installed") class NumPyTests(PySparkTestCase): + """General PySpark tests that depend on numpy """ def test_statcounter_array(self): - x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])]) + x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) s = x.stats() - self.assertSequenceEqual([2.0,2.0], s.mean().tolist()) - self.assertSequenceEqual([1.0,1.0], s.min().tolist()) - self.assertSequenceEqual([3.0,3.0], s.max().tolist()) - self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist()) + self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) + self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) if __name__ == "__main__": diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py index 8e4a6292bc17c..73fd26e71f10d 100755 --- a/python/test_support/userlibrary.py +++ b/python/test_support/userlibrary.py @@ -19,6 +19,8 @@ Used to test shipping of code depenencies with SparkContext.addPyFile(). """ + class UserClass(object): + def hello(self): return "Hello World!" diff --git a/tox.ini b/tox.ini index 44766e529bf7f..a1fefdd0e176f 100644 --- a/tox.ini +++ b/tox.ini @@ -15,3 +15,4 @@ [pep8] max-line-length=100 +exclude=cloudpickle.py From a65c9ac11e7075c2d7a925772273b9b7cf9586d6 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 6 Aug 2014 13:10:33 -0700 Subject: [PATCH 059/231] SPARK-2566. Update ShuffleWriteMetrics incrementally I haven't tested this out on a cluster yet, but wanted to make sure the approach (passing ShuffleWriteMetrics down to DiskBlockObjectWriter) was ok Author: Sandy Ryza Closes #1481 from sryza/sandy-spark-2566 and squashes the following commits: 8090d88 [Sandy Ryza] Fix ExternalSorter b2a62ed [Sandy Ryza] Fix more test failures 8be6218 [Sandy Ryza] Fix test failures and mark a couple variables private c5e68e5 [Sandy Ryza] SPARK-2566. Update ShuffleWriteMetrics incrementally (cherry picked from commit 4e982364426c7d65032e8006c63ca4f9a0d40470) Signed-off-by: Patrick Wendell --- .../apache/spark/executor/TaskMetrics.scala | 4 +- .../shuffle/hash/HashShuffleWriter.scala | 16 ++-- .../shuffle/sort/SortShuffleWriter.scala | 16 ++-- .../apache/spark/storage/BlockManager.scala | 12 +-- .../spark/storage/BlockObjectWriter.scala | 77 ++++++++++--------- .../spark/storage/ShuffleBlockManager.scala | 9 ++- .../collection/ExternalAppendOnlyMap.scala | 18 +++-- .../util/collection/ExternalSorter.scala | 17 ++-- .../storage/BlockObjectWriterSuite.scala | 65 ++++++++++++++++ .../spark/storage/DiskBlockManagerSuite.scala | 9 ++- .../spark/tools/StoragePerfTester.scala | 3 +- 11 files changed, 164 insertions(+), 82 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 56cd8723a3a22..11a6e10243211 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -190,10 +190,10 @@ class ShuffleWriteMetrics extends Serializable { /** * Number of bytes written for the shuffle by this task */ - var shuffleBytesWritten: Long = _ + @volatile var shuffleBytesWritten: Long = _ /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ - var shuffleWriteTime: Long = _ + @volatile var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 45d3b8b9b8725..51e454d9313c9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -39,10 +39,14 @@ private[spark] class HashShuffleWriter[K, V]( // we don't try deleting files, etc twice. private var stopping = false + private val writeMetrics = new ShuffleWriteMetrics() + metrics.shuffleWriteMetrics = Some(writeMetrics) + private val blockManager = SparkEnv.get.blockManager private val shuffleBlockManager = blockManager.shuffleBlockManager private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) - private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) + private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser, + writeMetrics) /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { @@ -99,22 +103,12 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - var totalTime = 0L val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => writer.commitAndClose() val size = writer.fileSegment().length - totalBytes += size - totalTime += writer.timeWriting() MapOutputTracker.compressSize(size) } - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - metrics.shuffleWriteMetrics = Some(shuffleMetrics) - new MapStatus(blockManager.blockManagerId, compressedSizes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 24db2f287a47b..e54e6383d2ccc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -52,6 +52,9 @@ private[spark] class SortShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null + private val writeMetrics = new ShuffleWriteMetrics() + context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) + /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { // Get an iterator with the elements for each partition ID @@ -84,13 +87,10 @@ private[spark] class SortShuffleWriter[K, V, C]( val offsets = new Array[Long](numPartitions + 1) val lengths = new Array[Long](numPartitions) - // Statistics - var totalBytes = 0L - var totalTime = 0L - for ((id, elements) <- partitions) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize) + val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize, + writeMetrics) for (elem <- elements) { writer.write(elem) } @@ -98,18 +98,12 @@ private[spark] class SortShuffleWriter[K, V, C]( val segment = writer.fileSegment() offsets(id + 1) = segment.offset + segment.length lengths(id) = segment.length - totalTime += writer.timeWriting() - totalBytes += segment.length } else { // The partition is empty; don't create a new writer to avoid writing headers, etc offsets(id + 1) = offsets(id) } } - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics) context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3876cf43e2a7d..8d21b02b747ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -562,17 +562,19 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. This is currently used for - * writing shuffle files out. Callers should handle error cases. + * The Block will be appended to the File specified by filename. Callers should handle error + * cases. */ def getDiskWriter( blockId: BlockId, file: File, serializer: Serializer, - bufferSize: Int): BlockObjectWriter = { + bufferSize: Int, + writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites) + new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites, + writeMetrics) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 01d46e1ffc960..adda971fd7b47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -22,6 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} +import org.apache.spark.executor.ShuffleWriteMetrics /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -60,41 +61,26 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { * This is only valid after commitAndClose() has been called. */ def fileSegment(): FileSegment - - /** - * Cumulative time spent performing blocking writes, in ns. - */ - def timeWriting(): Long - - /** - * Number of bytes written so far - */ - def bytesWritten: Long } -/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */ +/** + * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * The given write metrics will be updated incrementally, but will not necessarily be current until + * commitAndClose is called. + */ private[spark] class DiskBlockObjectWriter( blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int, compressStream: OutputStream => OutputStream, - syncWrites: Boolean) + syncWrites: Boolean, + writeMetrics: ShuffleWriteMetrics) extends BlockObjectWriter(blockId) with Logging { - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def timeWriting = _timeWriting - private var _timeWriting = 0L - - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - _timeWriting += (System.nanoTime() - start) - } - def write(i: Int): Unit = callWithTiming(out.write(i)) override def write(b: Array[Byte]) = callWithTiming(out.write(b)) override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) @@ -111,7 +97,11 @@ private[spark] class DiskBlockObjectWriter( private val initialPosition = file.length() private var finalPosition: Long = -1 private var initialized = false - private var _timeWriting = 0L + + /** Calling channel.position() to update the write metrics can be a little bit expensive, so we + * only call it every N writes */ + private var writesSinceMetricsUpdate = 0 + private var lastPosition = initialPosition override def open(): BlockObjectWriter = { fos = new FileOutputStream(file, true) @@ -128,14 +118,11 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - val start = System.nanoTime() - fos.getFD.sync() - _timeWriting += System.nanoTime() - start + def sync = fos.getFD.sync() + callWithTiming(sync) } objOut.close() - _timeWriting += ts.timeWriting - channel = null bs = null fos = null @@ -153,6 +140,7 @@ private[spark] class DiskBlockObjectWriter( // serializer stream and the lower level stream. objOut.flush() bs.flush() + updateBytesWritten() close() } finalPosition = file.length() @@ -162,6 +150,8 @@ private[spark] class DiskBlockObjectWriter( // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { + writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition) + if (initialized) { objOut.flush() bs.flush() @@ -184,19 +174,36 @@ private[spark] class DiskBlockObjectWriter( if (!initialized) { open() } + objOut.writeObject(value) + + if (writesSinceMetricsUpdate == 32) { + writesSinceMetricsUpdate = 0 + updateBytesWritten() + } else { + writesSinceMetricsUpdate += 1 + } } override def fileSegment(): FileSegment = { - new FileSegment(file, initialPosition, bytesWritten) + new FileSegment(file, initialPosition, finalPosition - initialPosition) } - // Only valid if called after close() - override def timeWriting() = _timeWriting + private def updateBytesWritten() { + val pos = channel.position() + writeMetrics.shuffleBytesWritten += (pos - lastPosition) + lastPosition = pos + } + + private def callWithTiming(f: => Unit) = { + val start = System.nanoTime() + f + writeMetrics.shuffleWriteTime += (System.nanoTime() - start) + } - // Only valid if called after commit() - override def bytesWritten: Long = { - assert(finalPosition != -1, "bytesWritten is only valid after successful commit()") - finalPosition - initialPosition + // For testing + private[spark] def flush() { + objOut.flush() + bs.flush() } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index f9fdffae8bd8f..3565719b54545 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.executor.ShuffleWriteMetrics /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -111,7 +112,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { + def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + writeMetrics: ShuffleWriteMetrics) = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) @@ -121,7 +123,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize) + blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, + writeMetrics) } } else { Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => @@ -136,7 +139,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { logWarning(s"Failed to remove existing shuffle file $blockFile") } } - blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize) + blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 260a5c3888aa7..9f85b94a70800 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -31,6 +31,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator +import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: @@ -102,6 +103,10 @@ class ExternalAppendOnlyMap[K, V, C]( private var _diskBytesSpilled = 0L private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() @@ -172,7 +177,9 @@ class ExternalAppendOnlyMap[K, V, C]( logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() - var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, + curWriteMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk @@ -183,9 +190,8 @@ class ExternalAppendOnlyMap[K, V, C]( val w = writer writer = null w.commitAndClose() - val bytesWritten = w.bytesWritten - batchSizes.append(bytesWritten) - _diskBytesSpilled += bytesWritten + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) objectsWritten = 0 } @@ -199,7 +205,9 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, + curWriteMetrics) } } if (objectsWritten > 0) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3f93afd57b3ad..eb4849ebc6e52 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -28,6 +28,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.BlockId +import org.apache.spark.executor.ShuffleWriteMetrics /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -112,11 +113,14 @@ private[spark] class ExternalSorter[K, V, C]( // What threshold of elementsRead we start estimating map size at. private val trackMemoryThreshold = 1000 - // Spilling statistics + // Total spilling statistics private var spillCount = 0 private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L @@ -239,7 +243,8 @@ private[spark] class ExternalSorter[K, V, C]( logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // Objects written since the last flush // List of batch sizes (bytes) in the order they are written to disk @@ -254,9 +259,8 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - val bytesWritten = w.bytesWritten - batchSizes.append(bytesWritten) - _diskBytesSpilled += bytesWritten + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) objectsWritten = 0 } @@ -275,7 +279,8 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala new file mode 100644 index 0000000000000..bbc7e1357b90d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.scalatest.FunSuite +import java.io.File +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.SparkConf + +class BlockObjectWriterSuite extends FunSuite { + test("verify write metrics") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20)) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + writer.commitAndClose() + assert(file.length() == writeMetrics.shuffleBytesWritten) + } + + test("verify write metrics on revert") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20)) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleBytesWritten == 0) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 985ac9394738c..b8299e2ea187f 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.SparkConf import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.executor.ShuffleWriteMetrics class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -153,7 +154,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before val shuffleManager = store.shuffleBlockManager - val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer) + val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics) for (writer <- shuffle1.writers) { writer.write("test1") writer.write("test2") @@ -165,7 +166,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before val shuffle1Segment = shuffle1.writers(0).fileSegment() shuffle1.releaseWriters(success = true) - val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf)) + val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf), + new ShuffleWriteMetrics) for (writer <- shuffle2.writers) { writer.write("test3") @@ -183,7 +185,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // of block based on remaining data in file : which could mess things up when there is concurrent read // and writes happening to the same shuffle group. - val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf)) + val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), + new ShuffleWriteMetrics) for (writer <- shuffle3.writers) { writer.write("test3") writer.write("test4") diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8a05fcb449aa6..17bf7c2541d13 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils +import org.apache.spark.executor.ShuffleWriteMetrics /** * Internal utility for micro-benchmarking shuffle write performance. @@ -56,7 +57,7 @@ object StoragePerfTester { def writeOutputBytes(mapId: Int, total: AtomicLong) = { val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, - new KryoSerializer(sc.conf)) + new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) From e654cfdd02e56fd3aaf6b784dcd25cb9ec35aece Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 6 Aug 2014 14:07:51 -0700 Subject: [PATCH 060/231] [SPARK-2852][MLLIB] API consistency for `mllib.feature` This is part of SPARK-2828: 1. added a Java-friendly fit method to Word2Vec with tests 2. change DeveloperApi to Experimental for Normalizer & StandardScaler 3. change default feature dimension to 2^20 in HashingTF Author: Xiangrui Meng Closes #1807 from mengxr/feature-api-check and squashes the following commits: 773c1a9 [Xiangrui Meng] change default numFeatures to 2^20 in HashingTF change annotation from DeveloperApi to Experimental in Normalizer and StandardScaler 883e122 [Xiangrui Meng] add @Experimental to Word2VecModel add a Java-friendly method to Word2Vec.fit with tests (cherry picked from commit 25cff1019da9d6cfc486a31d035b372ea5fbdfd2) Signed-off-by: Xiangrui Meng --- .../spark/mllib/feature/HashingTF.scala | 4 +- .../spark/mllib/feature/Normalizer.scala | 6 +- .../spark/mllib/feature/StandardScaler.scala | 6 +- .../apache/spark/mllib/feature/Word2Vec.scala | 19 +++++- .../mllib/feature/JavaWord2VecSuite.java | 66 +++++++++++++++++++ 5 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 0f6d5809e098f..c53475818395f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.Utils * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. * - * @param numFeatures number of features (default: 1000000) + * @param numFeatures number of features (default: 2^20^) */ @Experimental class HashingTF(val numFeatures: Int) extends Serializable { - def this() = this(1000000) + def this() = this(1 << 20) /** * Returns the index of the input term. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index ea9fd0a80d8e0..3afb47767281c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * :: DeveloperApi :: + * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * * For any 1 <= p < Double.PositiveInfinity, normalizes samples using @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} * * @param p Normalization in L^p^ space, p = 2 by default. */ -@DeveloperApi +@Experimental class Normalizer(p: Double) extends VectorTransformer { def this() = this(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index cc2d7579c2901..e6c9f8f67df63 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** - * :: DeveloperApi :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. * @@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ -@DeveloperApi +@Experimental class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { def this() = this(false, true) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 3bf44ad7c44e3..395037e1ec47c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -17,6 +17,9 @@ package org.apache.spark.mllib.feature +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -25,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ @@ -239,7 +243,7 @@ class Word2Vec extends Serializable with Logging { a += 1 } } - + /** * Computes the vector representation of each word in vocabulary. * @param dataset an RDD of words @@ -369,11 +373,22 @@ class Word2Vec extends Serializable with Logging { new Word2VecModel(word2VecMap.toMap) } + + /** + * Computes the vector representation of each word in vocabulary (Java version). + * @param dataset a JavaRDD of words + * @return a Word2VecModel + */ + def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { + fit(dataset.rdd.map(_.asScala)) + } } /** -* Word2Vec model + * :: Experimental :: + * Word2Vec model */ +@Experimental class Word2VecModel private[mllib] ( private val model: Map[String, Array[Float]]) extends Serializable { diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java new file mode 100644 index 0000000000000..fb7afe8c6434b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import com.google.common.base.Strings; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaWord2VecSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaWord2VecSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void word2Vec() { + // The tests are to check Java compatibility. + String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); + List words = Lists.newArrayList(sentence.split(" ")); + List> localDoc = Lists.newArrayList(words, words); + JavaRDD> doc = sc.parallelize(localDoc); + Word2Vec word2vec = new Word2Vec() + .setVectorSize(10) + .setSeed(42L); + Word2VecModel model = word2vec.fit(doc); + Tuple2[] syms = model.findSynonyms("a", 2); + Assert.assertEquals(2, syms.length); + Assert.assertEquals("b", syms[0]._1()); + Assert.assertEquals("c", syms[1]._1()); + } +} From a314e293f40c05991522d145e7d39b460b47f615 Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Wed, 6 Aug 2014 14:12:21 -0700 Subject: [PATCH 061/231] [PySpark] Add blanklines to Python docstrings so example code renders correctly Author: RJ Nowling Closes #1808 from rnowling/pyspark_docs and squashes the following commits: c06d774 [RJ Nowling] Add blanklines to Python docstrings so example code renders correctly (cherry picked from commit e537b33c63d3fb373fe41deaa607d72e76e3906b) Signed-off-by: Xiangrui Meng --- python/pyspark/rdd.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 30b834d2085cd..756e8f35fb03d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -134,6 +134,7 @@ class MaxHeapQ(object): """ An implementation of MaxHeap. + >>> import pyspark.rdd >>> heap = pyspark.rdd.MaxHeapQ(5) >>> [heap.insert(i) for i in range(10)] @@ -381,6 +382,7 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False): def getNumPartitions(self): """ Returns the number of partitions in RDD + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) >>> rdd.getNumPartitions() 2 @@ -570,6 +572,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. # noqa + >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] @@ -1209,6 +1212,7 @@ def collectAsMap(self): def keys(self): """ Return an RDD with the keys of each tuple. + >>> m = sc.parallelize([(1, 2), (3, 4)]).keys() >>> m.collect() [1, 3] @@ -1218,6 +1222,7 @@ def keys(self): def values(self): """ Return an RDD with the values of each tuple. + >>> m = sc.parallelize([(1, 2), (3, 4)]).values() >>> m.collect() [2, 4] @@ -1642,6 +1647,7 @@ def repartition(self, numPartitions): Internally, this uses a shuffle to redistribute data. If you are decreasing the number of partitions in this RDD, consider using `coalesce`, which can avoid performing a shuffle. + >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) >>> sorted(rdd.glom().collect()) [[1], [2, 3], [4, 5], [6, 7]] @@ -1656,6 +1662,7 @@ def repartition(self, numPartitions): def coalesce(self, numPartitions, shuffle=False): """ Return a new RDD that is reduced into `numPartitions` partitions. + >>> sc.parallelize([1, 2, 3, 4, 5], 3).glom().collect() [[1], [2, 3], [4, 5]] >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() @@ -1694,6 +1701,7 @@ def name(self): def setName(self, name): """ Assign a name to this RDD. + >>> rdd1 = sc.parallelize([1,2]) >>> rdd1.setName('RDD1') >>> rdd1.name() @@ -1753,6 +1761,7 @@ class PipelinedRDD(RDD): """ Pipelined maps: + >>> rdd = sc.parallelize([1, 2, 3, 4]) >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() [4, 8, 12, 16] From c2ae0b03669c72f5b842dc0cb4ba1f808c9ef702 Mon Sep 17 00:00:00 2001 From: Gregory Owen Date: Wed, 6 Aug 2014 16:52:00 -0700 Subject: [PATCH 062/231] SPARK-2882: Spark build now checks local maven cache for dependencies Fixes [SPARK-2882](https://issues.apache.org/jira/browse/SPARK-2882) Author: Gregory Owen Closes #1818 from GregOwen/spark-2882 and squashes the following commits: 294446d [Gregory Owen] SPARK-2882: Spark build now checks local maven cache for dependencies (cherry picked from commit 4e008334ee0fb60f9fe8820afa06f7b7f0fa7a6c) Signed-off-by: Patrick Wendell --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 40b588512ff08..ed587783d5606 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -115,7 +115,8 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - + + resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) From 3f92ce4e2270f3a1bf5303af78763230dd6cca5c Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 6 Aug 2014 17:27:55 -0700 Subject: [PATCH 063/231] [SPARK-2583] ConnectionManager error reporting This patch modifies the ConnectionManager so that error messages are sent in reply when uncaught exceptions occur during message processing. This prevents message senders from hanging while waiting for an acknowledgment if the remote message processing failed. This is an updated version of sarutak's PR, #1490. The main change is to use Futures / Promises to signal errors. Author: Kousuke Saruta Author: Josh Rosen Closes #1758 from JoshRosen/connection-manager-fixes and squashes the following commits: 68620cb [Josh Rosen] Fix test in BlockFetcherIteratorSuite: 83673de [Josh Rosen] Error ACKs should trigger IOExceptions, so catch only those exceptions in the test. b8bb4d4 [Josh Rosen] Fix manager.id vs managerServer.id typo that broke security tests. 659521f [Josh Rosen] Include previous exception when throwing new one a2f745c [Josh Rosen] Remove sendMessageReliablySync; callers can wait themselves. c01c450 [Josh Rosen] Return Try[Message] from sendMessageReliablySync. f1cd1bb [Josh Rosen] Clean up @sarutak's PR #1490 for [SPARK-2583]: ConnectionManager error reporting 7399c6b [Josh Rosen] Merge remote-tracking branch 'origin/pr/1490' into connection-manager-fixes ee91bb7 [Kousuke Saruta] Modified BufferMessage.scala to keep the spark code style 9dfd0d8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 e7d9aa6 [Kousuke Saruta] rebase to master 326a17f [Kousuke Saruta] Add test cases to ConnectionManagerSuite.scala for SPARK-2583 2a18d6b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 22d7ebd [Kousuke Saruta] Add test cases to BlockManagerSuite for SPARK-2583 e579302 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 281589c [Kousuke Saruta] Add a test case to BlockFetcherIteratorSuite.scala for fetching block from remote from successfully 0654128 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 ffaa83d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 12d3de8 [Kousuke Saruta] Added BlockFetcherIteratorSuite.scala 4117b8f [Kousuke Saruta] Modified ConnectionManager to be alble to handle error during processing message 717c9c3 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 6635467 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 e2b8c4a [Kousuke Saruta] Modify to propagete error using ConnectionManager (cherry picked from commit 17caae48b3608552dd6e3ae652043831f932ce95) Signed-off-by: Patrick Wendell --- .../apache/spark/network/BufferMessage.scala | 7 +- .../spark/network/ConnectionManager.scala | 143 ++++++++++-------- .../org/apache/spark/network/Message.scala | 2 + .../spark/network/MessageChunkHeader.scala | 7 +- .../org/apache/spark/network/SenderTest.scala | 7 +- .../spark/storage/BlockFetcherIterator.scala | 9 +- .../spark/storage/BlockManagerWorker.scala | 30 ++-- .../network/ConnectionManagerSuite.scala | 38 ++++- .../storage/BlockFetcherIteratorSuite.scala | 98 +++++++++++- .../spark/storage/BlockManagerSuite.scala | 110 +++++++++++++- 10 files changed, 362 insertions(+), 89 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index 04df2f3b0d696..af35f1fc3e459 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, + hasError, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 4c00225280cce..95f96b8463a01 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ @@ -45,16 +46,26 @@ private[spark] class ConnectionManager( name: String = "Connection manager") extends Logging { + /** + * Used by sendMessageReliably to track messages being sent. + * @param message the message that was sent + * @param connectionManagerId the connection manager that sent this message + * @param completionHandler callback that's invoked when the send has completed or failed + */ class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, completionHandler: MessageStatus => Unit) { + /** This is non-None if message has been ack'd */ var ackMessage: Option[Message] = None - var attempted = false - var acked = false - def markDone() { completionHandler(this) } + def markDone(ackMessage: Option[Message]) { + this.synchronized { + this.ackMessage = ackMessage + completionHandler(this) + } + } } private val selector = SelectorProvider.provider.openSelector() @@ -442,11 +453,7 @@ private[spark] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + status.markDone(None) }) messageStatuses.retain((i, status) => { @@ -475,11 +482,7 @@ private[spark] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() - } + s.markDone(None) } messageStatuses.retain((i, status) => { @@ -547,13 +550,13 @@ private[spark] class ConnectionManager( val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString) val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security message") + if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { case e: Exception => { logError("Error handling sasl client authentication", e) waitingConn.close() - throw new Exception("Error evaluating sasl response: " + e) + throw new IOException("Error evaluating sasl response: ", e) } } } @@ -661,34 +664,39 @@ private[spark] class ConnectionManager( } } } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.markDone() - } + sentMessageStatus.markDone(Some(message)) } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } + var ackMessage : Option[Message] = None + try { + ackMessage = if (onReceiveCallback != null) { + logDebug("Calling back") + onReceiveCallback(bufferMessage, connectionManagerId) + } else { + logDebug("Not calling back as callback is null") + None + } - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + if (ackMessage.isDefined) { + if (!ackMessage.get.isInstanceOf[BufferMessage]) { + logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + + ackMessage.get.getClass) + } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { + logDebug("Response to " + bufferMessage + " does not have ack id set") + ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + } + } + } catch { + case e: Exception => { + logError(s"Exception was thrown while processing message", e) + val m = Message.createBufferMessage(bufferMessage.id) + m.hasError = true + ackMessage = Some(m) } + } finally { + sendMessage(connectionManagerId, ackMessage.getOrElse { + Message.createBufferMessage(bufferMessage.id) + }) } - - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) } } case _ => throw new Exception("Unknown type message received") @@ -800,11 +808,7 @@ private[spark] class ConnectionManager( case Some(msgStatus) => { messageStatuses -= message.id logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.synchronized { - msgStatus.attempted = true - msgStatus.acked = false - msgStatus.markDone() - } + msgStatus.markDone(None) } case None => { logError("no messageStatus for failed message id: " + message.id) @@ -823,11 +827,28 @@ private[spark] class ConnectionManager( selector.wakeup() } + /** + * Send a message and block until an acknowldgment is received or an error occurs. + * @param connectionManagerId the message's destination + * @param message the message being sent + * @return a Future that either returns the acknowledgment message or captures an exception. + */ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Option[Message]] = { - val promise = Promise[Option[Message]] - val status = new MessageStatus( - message, connectionManagerId, s => promise.success(s.ackMessage)) + : Future[Message] = { + val promise = Promise[Message]() + val status = new MessageStatus(message, connectionManagerId, s => { + s.ackMessage match { + case None => // Indicates a failure where we either never sent or never got ACK'd + promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) + case Some(ackMessage) => + if (ackMessage.hasError) { + promise.failure( + new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + } else { + promise.success(ackMessage) + } + } + }) messageStatuses.synchronized { messageStatuses += ((message.id, status)) } @@ -835,11 +856,6 @@ private[spark] class ConnectionManager( promise.future } - def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, - message: Message): Option[Message] = { - Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) - } - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { onReceiveCallback = callback } @@ -862,6 +878,7 @@ private[spark] class ConnectionManager( private[spark] object ConnectionManager { + import ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf @@ -896,7 +913,7 @@ private[spark] object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) }) println("--------------------------") println() @@ -917,8 +934,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -952,8 +971,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -982,8 +1003,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis Thread.sleep(1000) diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 7caccfdbb44f9..04ea50f62918c 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var startTime = -1L var finishTime = -1L var isSecurityNeg = false + var hasError = false def size: Int @@ -87,6 +88,7 @@ private[spark] object Message { case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } + newMessage.hasError = header.hasError newMessage.senderAddress = header.address newMessage } diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index ead663ede7a1c..f3ecca5f992e0 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val hasError: Boolean, val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { @@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). putInt(securityNeg). putInt(ip.size). put(ip). @@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader( private[spark] object MessageChunkHeader { - val HEADER_SIZE = 44 + val HEADER_SIZE = 45 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val hasError = buffer.get() != 0 val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index b8ea7c2cff9a2..ea2ad104ecae1 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -20,6 +20,10 @@ package org.apache.spark.network import java.nio.ByteBuffer import org.apache.spark.{SecurityManager, SparkConf} +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.Try + private[spark] object SenderTest { def main(args: Array[String]) { @@ -51,7 +55,8 @@ private[spark] object SenderTest { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis /* println("Started timer at " + startTime) */ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) + val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) + val responseStr: String = Try(Await.result(promise, Duration.Inf)) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) new String(buffer.array, "utf-8") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index ccf830e118ee7..938af6f5b923a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue +import scala.util.{Failure, Success} import io.netty.buffer.ByteBuf @@ -118,8 +119,8 @@ object BlockFetcherIterator { bytesInFlight += req.size val sizeMap = req.blocks.toMap // so we can look up the size of each blockID val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { + future.onComplete { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) for (blockMessage <- blockMessageArray) { @@ -135,8 +136,8 @@ object BlockFetcherIterator { logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } - case None => { - logError("Could not get block(s) from " + cmId) + case Failure(exception) => { + logError("Could not get block(s) from " + cmId, exception) for ((blockId, size) <- req.blocks) { results.put(new FetchResult(blockId, -1, null)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index c7766a3a65671..bf002a42d5dc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -23,6 +23,10 @@ import org.apache.spark.Logging import org.apache.spark.network._ import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.{Try, Failure, Success} + /** * A network interface for BlockManager. Each slave should have one * BlockManagerWorker. @@ -44,13 +48,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => logError("Exception handling buffer message", e) - None + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } } } case otherMessage: Any => { logError("Unknown type message received: " + otherMessage) - None + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) } } } @@ -109,9 +119,9 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - resultMessage.isDefined + val resultMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) + resultMessage.isSuccess } def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { @@ -119,10 +129,10 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) + val responseMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) responseMessage match { - case Some(message) => { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] logDebug("Response message received " + bufferMessage) BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { @@ -130,7 +140,7 @@ private[spark] object BlockManagerWorker extends Logging { return blockMessage.getData }) } - case None => logDebug("No response message received") + case Failure(exception) => logDebug("No response message received") } null } diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 415ad8c432c12..846537df003df 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import org.apache.spark.{SecurityManager, SparkConf} @@ -25,6 +26,7 @@ import org.scalatest.FunSuite import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Try /** * Test the ConnectionManager with various security settings. @@ -46,7 +48,7 @@ class ConnectionManagerSuite extends FunSuite { buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) assert(receivedMessage == true) @@ -79,7 +81,7 @@ class ConnectionManagerSuite extends FunSuite { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) }) assert(numReceivedServerMessages == 10) @@ -118,7 +120,10 @@ class ConnectionManagerSuite extends FunSuite { val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + // Expect managerServer to close connection, which we'll report as an error: + intercept[IOException] { + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) + } assert(numReceivedServerMessages == 0) assert(numReceivedMessages == 0) @@ -163,6 +168,8 @@ class ConnectionManagerSuite extends FunSuite { val g = Await.result(f, 1 second) assert(false) } catch { + case i: IOException => + assert(true) case e: TimeoutException => { // we should timeout here since the client can't do the negotiation assert(true) @@ -209,7 +216,6 @@ class ConnectionManagerSuite extends FunSuite { }).foreach(f => { try { val g = Await.result(f, 1 second) - if (!g.isDefined) assert(false) else assert(true) } catch { case e: Exception => { assert(false) @@ -223,7 +229,31 @@ class ConnectionManagerSuite extends FunSuite { managerServer.stop() } + test("Ack error message") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + val managerServer = new ConnectionManager(0, conf, securityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + throw new Exception + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + intercept[IOException] { + Await.result(future, 1 second) + } + manager.stop() + managerServer.stop() + + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 8dca2ebb312f5..1538995a6b404 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -17,18 +17,22 @@ package org.apache.spark.storage +import java.io.IOException +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global + import org.scalatest.{FunSuite, Matchers} -import org.scalatest.PrivateMethodTester._ import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.stubbing.Answer import org.mockito.invocation.InvocationOnMock -import org.apache.spark._ import org.apache.spark.storage.BlockFetcherIterator._ -import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, - Message} +import org.apache.spark.network.{ConnectionManager, Message} class BlockFetcherIteratorSuite extends FunSuite with Matchers { @@ -137,4 +141,90 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") } + test("block fetch from remote fails using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val f = future { + throw new IOException("Send failed or we received an error ACK") + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (!r.isDefined) should be(true) + } + } + } + + test("block fetch from remote succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1)) + val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2)) + val blockMessageArray = new BlockMessageArray( + Seq(blockMessage1, blockMessage2)) + + val bufferMessage = blockMessageArray.toBufferMessage + val buffer = ByteBuffer.allocate(bufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + bufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + + val f = future { + Message.createBufferMessage(arrayBuffer) + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (r.isDefined) should be(true) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0ac0269d7cfc1..94bb2c445d2e9 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -25,7 +25,11 @@ import akka.actor._ import akka.pattern.ask import akka.util.Timeout -import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.Matchers.any +import org.mockito.Mockito.{doAnswer, mock, spy, when} +import org.mockito.stubbing.Answer + import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -33,6 +37,7 @@ import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod +import org.apache.spark.network.{Message, ConnectionManagerId} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1000,6 +1005,109 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } + test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + val reqBufferMessage = reqBlockMessages.toBufferMessage + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + throw new Exception + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + // Test when exception was thrown during processing block messages + var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + + assert(ackMessage.isDefined, "When Exception was thrown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When Exception was thown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should have error") + + val notBufferMessage = mock(classOf[Message]) + + // Test when not BufferMessage was received + ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should have error") + } + + test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + + val tmpBufferMessage = reqBlockMessages.toBufferMessage + val buffer = ByteBuffer.allocate(tmpBufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + tmpBufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + val reqBufferMessage = Message.createBufferMessage(arrayBuffer) + + // setup ack block messages + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1)) + val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2)) + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq( + reqBlockMessage1)) { + return Some(ackBlockMessage1) + } else { + return Some(ackBlockMessage2) + } + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should be defined") + assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should not have error") + } + test("reserve/release unroll memory") { store = makeBlockManager(12000) val memoryStore = store.memoryStore From 40284a9a32a6efb6195098c93e292cbc6d128c42 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 6 Aug 2014 18:13:35 -0700 Subject: [PATCH 064/231] SPARK-2879 [BUILD] Use HTTPS to access Maven Central and other repos Maven Central has just now enabled HTTPS access for everyone to Maven Central (http://central.sonatype.org/articles/2014/Aug/03/https-support-launching-now/) This is timely, as a reminder of how easily an attacker can slip malicious code into a build that's downloading artifacts over HTTP (http://blog.ontoillogical.com/blog/2014/07/28/how-to-take-over-any-java-developer/). In the meantime, it looks like the Spring repo also now supports HTTPS, so can be used this way too. I propose to use HTTPS to access these repos. Author: Sean Owen Closes #1805 from srowen/SPARK-2879 and squashes the following commits: 7043a8e [Sean Owen] Use HTTPS for Maven Central libs and plugins; use id 'central' to override parent properly; use HTTPS for Spring repo (cherry picked from commit 4201d2711cd20a2892c40eb11102f73c2f826b2e) Signed-off-by: Patrick Wendell --- pom.xml | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 4ab027bad55c0..76bf6d8f902a8 100644 --- a/pom.xml +++ b/pom.xml @@ -143,11 +143,11 @@ - maven-repo + central Maven Repository - http://repo.maven.apache.org/maven2 + https://repo.maven.apache.org/maven2 true @@ -213,7 +213,7 @@ spring-releases Spring Release Repository - http://repo.spring.io/libs-release + https://repo.spring.io/libs-release true @@ -222,6 +222,15 @@ + + + central + https://repo1.maven.org/maven2 + + true + + + From 53fa0486af202b76dfea08d541c5d874731f81fb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 6 Aug 2014 18:45:03 -0700 Subject: [PATCH 065/231] HOTFIX: Support custom Java 7 location --- dev/create-release/create-release.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 42473629d4f15..1867cf4ec46ca 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -35,6 +35,12 @@ RELEASE_VERSION=${RELEASE_VERSION:-1.0.0} RC_NAME=${RC_NAME:-rc2} USER_NAME=${USER_NAME:-pwendell} +if [ -z "$JAVA_HOME" ]; then + echo "Error: JAVA_HOME is not set, cannot proceed." + exit -1 +fi +JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} + set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME @@ -130,7 +136,8 @@ scp spark-* \ cd spark sbt/sbt clean cd docs -PRODUCTION=1 jekyll build +# Compile docs with Java 7 to use nicer format +JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs ssh $USER_NAME@people.apache.org \ From cf35b56d4daed1bb4de3084825842fc750c830f1 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 6 Aug 2014 19:11:39 -0700 Subject: [PATCH 066/231] Updating versions for Spark 1.1.0 --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- docs/_config.yml | 4 ++-- ec2/spark_ec2.py | 2 +- extras/java8-tests/pom.xml | 2 +- python/epydoc.conf | 2 +- python/pyspark/shell.py | 2 +- .../src/main/scala/org/apache/spark/repl/SparkILoopInit.scala | 2 +- yarn/alpha/pom.xml | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e132955f0f850..0470fbeed1ada 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1312,7 +1312,7 @@ class SparkContext(config: SparkConf) extends Logging { */ object SparkContext extends Logging { - private[spark] val SPARK_VERSION = "1.0.0" + private[spark] val SPARK_VERSION = "1.1.0" private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" diff --git a/docs/_config.yml b/docs/_config.yml index 45b78fe724a50..84db61876b82d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -3,8 +3,8 @@ markdown: kramdown # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.0.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.0.0 +SPARK_VERSION: 1.1.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.1.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.18.1 diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0c2f85a3868f4..fc6fb1db59424 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -196,7 +196,7 @@ def is_active(instance): def get_spark_shark_version(opts): spark_shark_map = { "0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1", - "1.0.0": "1.0.0" + "1.0.0": "1.0.0", "1.0.1": "1.0.1", "1.0.2": "1.0.2", "1.1.0": "1.1.0" } version = opts.spark_version.replace("v", "") if version not in spark_shark_map: diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 5308bb4e440ea..8ce7b94fee1be 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/python/epydoc.conf b/python/epydoc.conf index 51c0faf359939..d066ecb7712c5 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -18,7 +18,7 @@ # # Information about the project. -name: Spark 1.0.0 Python API Docs +name: Spark 1.1.0 Python API Docs url: http://spark.apache.org # The list of modules to document. Modules can be named using diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index e1e7cd954189f..8a9777a708b33 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -47,7 +47,7 @@ ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ - /__ / .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT + /__ / .__/\_,_/_/ /_/\_\ version 1.1.0 /_/ """) print("Using Python version %s (%s, %s)" % ( diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 910b31d209e13..74c58eb49d07f 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -26,7 +26,7 @@ trait SparkILoopInit { ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT + /___/ .__/\_,_/_/ /_/\_\ version 1.1.0 /_/ """) import Properties._ diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 51744ece0412d..72d9b1606ad9c 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml From d428d88418d385d1d04e1b0adcb6b068efe9c7b0 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 7 Aug 2014 03:16:14 +0000 Subject: [PATCH 067/231] [maven-release-plugin] prepare release v1.1.0-snapshot1 --- assembly/pom.xml | 6 +++--- bagel/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 9 ++++----- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 5 ++--- sql/core/pom.xml | 5 ++--- sql/hive-thriftserver/pom.xml | 5 ++--- sql/hive/pom.xml | 5 ++--- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 24 files changed, 33 insertions(+), 38 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 703f15925bc44..02dd3d15337e4 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml @@ -124,8 +124,8 @@ log4j.properties - - + +
    diff --git a/bagel/pom.xml b/bagel/pom.xml index bd51b112e26fa..8eec7e5dd23b5 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 6d8be37037729..83e60268afbd2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 8c4c128bb484d..9bde90ed116e1 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index d0bf1cf1ea796..fd317e70d0c8e 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index c532705f3950c..6563f4d73da01 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4e2275ab238f7..e3df55355c8d5 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index dc48a08c93de2..1f9e52b19b24e 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index b93ad016f84f0..d28741a9524dc 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 22c1fff23d9a2..7e7609985d2c0 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index a54b34235dfb4..4d0eac8956955 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index a5b162a0482e4..44f50aeefa62e 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 6dd52fc618b1e..c677f361cf1f2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 9a33bd1cf6ad1..9d92b0b4dc7b6 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/pom.xml b/pom.xml index 76bf6d8f902a8..a29de043d0dff 100644 --- a/pom.xml +++ b/pom.xml @@ -16,8 +16,7 @@ ~ limitations under the License. --> - + 4.0.0 org.apache @@ -26,7 +25,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 pom Spark Project Parent POM http://spark.apache.org/ @@ -41,7 +40,7 @@ scm:git:git@github.com:apache/spark.git scm:git:https://git-wip-us.apache.org/repos/asf/spark.git scm:git:git@github.com:apache/spark.git - HEAD + v1.1.0-snapshot1 @@ -878,7 +877,7 @@ . ${project.build.directory}/SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - + true ${session.executionRootDirectory} diff --git a/repl/pom.xml b/repl/pom.xml index 68f4504450778..b7458eeb270dd 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 58d44e7923bee..7c9e5b284e0d9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c8016e41256d5..d797753f12151 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index c6f60c18804a4..d75d2e514544d 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 93d00f7c37c9b..9e8989e55ef40 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 1072f74aea0d9..31c096380a7c1 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 97abb6b2b63e0..ce3629443ed98 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 3faaf053634d6..274be3a563641 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index b6c8456d06684..64fb00ac71b60 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml From c204a742a9eb9d3fd318e0f059bd00cbfb8b2c14 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 7 Aug 2014 03:16:23 +0000 Subject: [PATCH 068/231] [maven-release-plugin] prepare for next development iteration --- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 4 ++-- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 24 files changed, 25 insertions(+), 25 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 02dd3d15337e4..16e5271b35050 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 8eec7e5dd23b5..f29540b239c73 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 83e60268afbd2..debc4dd703d9a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 9bde90ed116e1..f35d3d6a788e3 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index fd317e70d0c8e..cfbf943bdafe0 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 6563f4d73da01..b127136e3f5a0 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index e3df55355c8d5..5123d0554639c 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 1f9e52b19b24e..9c00bfc8429a4 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index d28741a9524dc..1b9ef4af0c2ed 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7e7609985d2c0..60292a2683212 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 4d0eac8956955..58b995c5e7005 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 44f50aeefa62e..02c9676fb086a 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index c677f361cf1f2..656478583fac2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 9d92b0b4dc7b6..d78fed794470c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index a29de043d0dff..dcda3d53b5cb2 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -40,7 +40,7 @@ scm:git:git@github.com:apache/spark.git scm:git:https://git-wip-us.apache.org/repos/asf/spark.git scm:git:git@github.com:apache/spark.git - v1.1.0-snapshot1 + HEAD diff --git a/repl/pom.xml b/repl/pom.xml index b7458eeb270dd..8748ada36f57a 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 7c9e5b284e0d9..e2356381c07fb 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index d797753f12151..3efea9ab8b247 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index d75d2e514544d..c264ff4ec92e5 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 9e8989e55ef40..c18a664e737c8 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 31c096380a7c1..c0ce0d7c7478d 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index ce3629443ed98..c601fd5fbbee2 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 274be3a563641..18f27b827ff1a 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 64fb00ac71b60..2ba3baf0e3b2e 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml From cc8a7e97e1c9190fcb6093ad9c94e7f0730af94c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Aug 2014 21:22:13 -0700 Subject: [PATCH 069/231] [SPARK-2887] fix bug of countApproxDistinct() when have more than one partition fix bug of countApproxDistinct() when have more than one partition Author: Davies Liu Closes #1812 from davies/approx and squashes the following commits: bf757ce [Davies Liu] fix bug of countApproxDistinct() when have more than one partition (cherry picked from commit ffd1f59a62a9dd9a4d5a7b09490b9d01ff1cd42d) Signed-off-by: Patrick Wendell --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e1c49e35abecd..0159003c88e06 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1004,7 +1004,7 @@ abstract class RDD[T: ClassTag]( }, (h1: HyperLogLogPlus, h2: HyperLogLogPlus) => { h1.addAll(h2) - h2 + h1 }).cardinality() } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index b31e3a09e5b9c..4a7dc8dca25e2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -81,11 +81,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble - val size = 100 - val uniformDistro = for (i <- 1 to 100000) yield i % size - val simpleRdd = sc.makeRDD(uniformDistro) - assert(error(simpleRdd.countApproxDistinct(4, 0), size) < 0.4) - assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.1) + val size = 1000 + val uniformDistro = for (i <- 1 to 5000) yield i % size + val simpleRdd = sc.makeRDD(uniformDistro, 10) + assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.2) + assert(error(simpleRdd.countApproxDistinct(12, 0), size) < 0.1) } test("SparkContext.union") { From c9f09445878de462282b02855bda66072458bd5c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 6 Aug 2014 22:58:59 -0700 Subject: [PATCH 070/231] [SPARK-2851] [mllib] DecisionTree Python consistency update Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs). Added factory classes for Algo and Impurity, but made private[mllib]. CC: mengxr dorx Please let me know if there are other changes which would help with API consistency---thanks! Author: Joseph K. Bradley Closes #1798 from jkbradley/dt-python-consistency and squashes the following commits: 6f7edf8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency a0d7dbe [Joseph K. Bradley] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD. ee1d236 [Joseph K. Bradley] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types) 00f820e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency fe6dbfa [Joseph K. Bradley] removed unnecessary imports e358661 [Joseph K. Bradley] DecisionTree API change: * Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs). c699850 [Joseph K. Bradley] a few doc comments eaf84c0 [Joseph K. Bradley] Added DecisionTree static train() methods API to match Python, but without default parameters (cherry picked from commit 47ccd5e71be49b723476f3ff8d5768f0f45c2ea6) Signed-off-by: Xiangrui Meng --- .../mllib/api/python/PythonMLLibAPI.scala | 19 +-- .../spark/mllib/tree/DecisionTree.scala | 151 ++++++++++++++---- .../spark/mllib/tree/configuration/Algo.scala | 6 + .../mllib/tree/impurity/Impurities.scala | 32 ++++ python/pyspark/mllib/tree.py | 50 ++---- 5 files changed, 181 insertions(+), 77 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fd0b9556c7d54..ba7ccd8ce4b8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -25,16 +25,14 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.stat.correlation.CorrelationNames @@ -523,17 +521,8 @@ class PythonMLLibAPI extends Serializable { val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val algo: Algo = algoStr match { - case "classification" => Classification - case "regression" => Regression - case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") - } - val impurity: Impurity = impurityStr match { - case "gini" => Gini - case "entropy" => Entropy - case "variance" => Variance - case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") - } + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) val strategy = new Strategy( algo = algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1d03e6e3b36cf..c8a865659682f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,14 +17,18 @@ package org.apache.spark.mllib.tree +import org.apache.spark.api.java.JavaRDD + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging { * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ def train( @@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainClassifier(input.rdd, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + private val InvalidBinIndex = -1 /** @@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging { * Categorical features: * For each feature, there is 1 bin per split. * Splits and bins are handled in 2 ways: - * (a) For multiclass classification with a low-arity feature + * (a) "unordered features" + * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are 2^(maxFeatureValue - 1) - 1 splits. - * (b) For regression and binary classification, + * There are math.pow(2, maxFeatureValue - 1) - 1 splits. + * (b) "ordered features" + * For regression and binary classification, * and for multiclass classification with a high-arity feature, - * there is one split per category. - - * Categorical case (a) features are called unordered features. - * Other cases are called ordered features. + * there is one bin per category. * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 79a01f58319e8..0ef9c6181a0a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value + + private[mllib] def fromString(name: String): Algo = name match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala new file mode 100644 index 0000000000000..9a6452aa13a61 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Factory for Impurity instances. + */ +private[mllib] object Impurities { + + def fromString(name: String): Impurity = name match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name") + } + +} diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 2518001ea0b93..e1a4671709b7d 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -131,7 +131,7 @@ class DecisionTree(object): """ @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + def trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for classification. @@ -150,12 +150,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "classification", numClasses, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, "classification", + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) @staticmethod - def trainRegressor(data, categoricalFeaturesInfo={}, + def trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for regression. @@ -173,42 +181,14 @@ def trainRegressor(data, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "regression", 0, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - - @staticmethod - def train(data, algo, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins=100): - """ - Train a DecisionTreeModel for classification or regression. - - :param data: Training data: RDD of LabeledPoint. - For classification, labels are integers - {0,1,...,numClasses}. - For regression, labels are real numbers. - :param algo: "classification" or "regression" - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: For classification: "entropy" or "gini". - For regression: "variance". - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :return: DecisionTreeModel - """ sc = data.context dataBytes = _get_unmangled_labeled_point_rdd(data) categoricalFeaturesInfoJMap = \ MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, algo, - numClasses, categoricalFeaturesInfoJMap, + dataBytes._jrdd, "regression", + 0, categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins) dataBytes.unpersist() return DecisionTreeModel(sc, model) From d6cd6fd03b25c08582bc2d5a3654676154694ddf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 7 Aug 2014 00:04:18 -0700 Subject: [PATCH 071/231] SPARK-2879 part 2 [BUILD] Use HTTPS to access Maven Central and other repos .. and use canonical repo1.maven.org Maven Central repo. (And make sure snapshots are disabled for plugins from Maven Central.) Author: Sean Owen Closes #1828 from srowen/SPARK-2879.2 and squashes the following commits: 639f495 [Sean Owen] .. and use canonical repo1.maven.org Maven Central repo. (And make sure snapshots are disabled for plugins from Maven Central.) (cherry picked from commit 75993a65173172da32bbe98751e8c0f55c17a52e) Signed-off-by: Patrick Wendell --- pom.xml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index dcda3d53b5cb2..c87f776bda659 100644 --- a/pom.xml +++ b/pom.xml @@ -145,8 +145,7 @@ central Maven Repository - - https://repo.maven.apache.org/maven2 + https://repo1.maven.org/maven2 true @@ -228,6 +227,9 @@ true + + false + From c0894291bfbf041c98cf66004617712d085e8750 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 7 Aug 2014 00:20:38 -0700 Subject: [PATCH 072/231] [mllib] DecisionTree Strategy parameter checks Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters. CC mengxr Author: Joseph K. Bradley Closes #1821 from jkbradley/dt-robustness and squashes the following commits: 4dc449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-robustness 7a61f7b [Joseph K. Bradley] Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters (cherry picked from commit 8d1dec4fa4798bb48b8947446d306ec9ba6bddb5) Signed-off-by: Xiangrui Meng --- .../spark/mllib/tree/DecisionTree.scala | 10 ++++-- .../mllib/tree/configuration/Strategy.scala | 31 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c8a865659682f..bb50f07be5d7b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -44,6 +44,8 @@ import org.apache.spark.util.random.XORShiftRandom @Experimental class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { + strategy.assertValid() + /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] @@ -1465,10 +1467,14 @@ object DecisionTree extends Serializable with Logging { /* - * Ensure #bins is always greater than the categories. For multiclass classification, - * #bins should be greater than 2^(maxCategories - 1) - 1. + * Ensure numBins is always greater than the categories. For multiclass classification, + * numBins should be greater than 2^(maxCategories - 1) - 1. * It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. + * + * This needs to be checked here instead of in Strategy since numBins can be determined + * by the number of training examples. + * TODO: Allow this case, where we simply will know nothing about some categories. */ if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 4ee4bcd0bcbc7..f31a503608b22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -90,4 +90,33 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification >= 2, + s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," + + s" but numClassesForClassification = $numClassesForClassification.") + require(Set(Gini, Entropy).contains(impurity), + s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + + s" Valid settings: Gini, Entropy") + case Regression => + require(impurity == Variance, + s"DecisionTree Strategy given invalid impurity for Regression: $impurity." + + s" Valid settings: Variance") + case _ => + throw new IllegalArgumentException( + s"DecisionTree Strategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." + + s" Valid values are integers >= 0.") + require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + + s" Valid values are integers >= 2.") + categoricalFeaturesInfo.foreach { case (feature, arity) => + require(arity >= 2, + s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + + s" feature $feature has $arity categories. The number of categories should be >= 2.") + } + } + } From f705c1d5664b137fbd03a286c86d7c543c73ebe8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 7 Aug 2014 11:28:12 -0700 Subject: [PATCH 073/231] [SPARK-2852][MLLIB] Separate model from IDF/StandardScaler algorithms This is part of SPARK-2828: 1. separate IDF model from IDF algorithm (which generates a model) 2. separate StandardScaler model from StandardScaler CC: dbtsai Author: Xiangrui Meng Closes #1814 from mengxr/feature-api-update and squashes the following commits: 40d863b [Xiangrui Meng] move mean and variance to model 48a0fff [Xiangrui Meng] separate Model from StandardScaler algorithm 89f3486 [Xiangrui Meng] update IDF to separate Model from Algorithm (cherry picked from commit b9e9e53773a618e4322b845c40deae22f2ba52ac) Signed-off-by: Xiangrui Meng --- .../org/apache/spark/mllib/feature/IDF.scala | 130 ++++++++---------- .../spark/mllib/feature/StandardScaler.scala | 58 ++++---- .../apache/spark/mllib/feature/IDFSuite.scala | 12 +- .../mllib/feature/StandardScalerSuite.scala | 50 +++---- 4 files changed, 121 insertions(+), 129 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 7ed611a857acc..d40d5553c1d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -36,87 +36,25 @@ class IDF { // TODO: Allow different IDF formulations. - private var brzIdf: BDV[Double] = _ - /** * Computes the inverse document frequency. * @param dataset an RDD of term frequency vectors */ - def fit(dataset: RDD[Vector]): this.type = { - brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + def fit(dataset: RDD[Vector]): IDFModel = { + val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( seqOp = (df, v) => df.add(v), combOp = (df1, df2) => df1.merge(df2) ).idf() - this + new IDFModel(idf) } /** * Computes the inverse document frequency. * @param dataset a JavaRDD of term frequency vectors */ - def fit(dataset: JavaRDD[Vector]): this.type = { + def fit(dataset: JavaRDD[Vector]): IDFModel = { fit(dataset.rdd) } - - /** - * Transforms term frequency (TF) vectors to TF-IDF vectors. - * @param dataset an RDD of term frequency vectors - * @return an RDD of TF-IDF vectors - */ - def transform(dataset: RDD[Vector]): RDD[Vector] = { - if (!initialized) { - throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") - } - val theIdf = brzIdf - val bcIdf = dataset.context.broadcast(theIdf) - dataset.mapPartitions { iter => - val thisIdf = bcIdf.value - iter.map { v => - val n = v.size - v match { - case sv: SparseVector => - val nnz = sv.indices.size - val newValues = new Array[Double](nnz) - var k = 0 - while (k < nnz) { - newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) - k += 1 - } - Vectors.sparse(n, sv.indices, newValues) - case dv: DenseVector => - val newValues = new Array[Double](n) - var j = 0 - while (j < n) { - newValues(j) = dv.values(j) * thisIdf(j) - j += 1 - } - Vectors.dense(newValues) - case other => - throw new UnsupportedOperationException( - s"Only sparse and dense vectors are supported but got ${other.getClass}.") - } - } - } - } - - /** - * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). - * @param dataset a JavaRDD of term frequency vectors - * @return a JavaRDD of TF-IDF vectors - */ - def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { - transform(dataset.rdd).toJavaRDD() - } - - /** Returns the IDF vector. */ - def idf(): Vector = { - if (!initialized) { - throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") - } - Vectors.fromBreeze(brzIdf) - } - - private def initialized: Boolean = brzIdf != null } private object IDF { @@ -177,18 +115,72 @@ private object IDF { private def isEmpty: Boolean = m == 0L /** Returns the current IDF vector. */ - def idf(): BDV[Double] = { + def idf(): Vector = { if (isEmpty) { throw new IllegalStateException("Haven't seen any document yet.") } val n = df.length - val inv = BDV.zeros[Double](n) + val inv = new Array[Double](n) var j = 0 while (j < n) { inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) j += 1 } - inv + Vectors.dense(inv) } } } + +/** + * :: Experimental :: + * Represents an IDF model that can transform term frequency vectors. + */ +@Experimental +class IDFModel private[mllib] (val idf: Vector) extends Serializable { + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors. + * @param dataset an RDD of term frequency vectors + * @return an RDD of TF-IDF vectors + */ + def transform(dataset: RDD[Vector]): RDD[Vector] = { + val bcIdf = dataset.context.broadcast(idf) + dataset.mapPartitions { iter => + val thisIdf = bcIdf.value + iter.map { v => + val n = v.size + v match { + case sv: SparseVector => + val nnz = sv.indices.size + val newValues = new Array[Double](nnz) + var k = 0 + while (k < nnz) { + newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) + k += 1 + } + Vectors.sparse(n, sv.indices, newValues) + case dv: DenseVector => + val newValues = new Array[Double](n) + var j = 0 + while (j < n) { + newValues(j) = dv.values(j) * thisIdf(j) + j += 1 + } + Vectors.dense(newValues) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } + } + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). + * @param dataset a JavaRDD of term frequency vectors + * @return a JavaRDD of TF-IDF vectors + */ + def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(dataset.rdd).toJavaRDD() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index e6c9f8f67df63..4dfd1f0ab8134 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,8 +17,9 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ @@ -35,37 +36,55 @@ import org.apache.spark.rdd.RDD * @param withStd True by default. Scales the data to unit standard deviation. */ @Experimental -class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { +class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { def this() = this(false, true) - require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") - - private var mean: BV[Double] = _ - private var factor: BV[Double] = _ + if (!(withMean || withStd)) { + logWarning("Both withMean and withStd are false. The model does nothing.") + } /** * Computes the mean and variance and stores as a model to be used for later scaling. * * @param data The data used to compute the mean and variance to build the transformation model. - * @return This StandardScalar object. + * @return a StandardScalarModel */ - def fit(data: RDD[Vector]): this.type = { + def fit(data: RDD[Vector]): StandardScalerModel = { + // TODO: skip computation if both withMean and withStd are false val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) + } +} - mean = summary.mean.toBreeze - factor = summary.variance.toBreeze - require(mean.length == factor.length) +/** + * :: Experimental :: + * Represents a StandardScaler model that can transform vectors. + * + * @param withMean whether to center the data before scaling + * @param withStd whether to scale the data to have unit standard deviation + * @param mean column mean values + * @param variance column variance values + */ +@Experimental +class StandardScalerModel private[mllib] ( + val withMean: Boolean, + val withStd: Boolean, + val mean: Vector, + val variance: Vector) extends VectorTransformer { + + require(mean.size == variance.size) + private lazy val factor: BDV[Double] = { + val f = BDV.zeros[Double](variance.size) var i = 0 - while (i < factor.length) { - factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 + while (i < f.size) { + f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 i += 1 } - - this + f } /** @@ -76,13 +95,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor * for the column with zero variance. */ override def transform(vector: Vector): Vector = { - if (mean == null || factor == null) { - throw new IllegalStateException( - "Haven't learned column summary statistics yet. Call fit first.") - } - - require(vector.size == mean.length) - + require(mean.size == vector.size) if (withMean) { vector.toBreeze match { case dv: BDV[Double] => @@ -115,5 +128,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor vector } } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 78a2804ff204b..53d9c0c640b98 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -36,18 +36,12 @@ class IDFSuite extends FunSuite with LocalSparkContext { val m = localTermFrequencies.size val termFrequencies = sc.parallelize(localTermFrequencies, 2) val idf = new IDF - intercept[IllegalStateException] { - idf.idf() - } - intercept[IllegalStateException] { - idf.transform(termFrequencies) - } - idf.fit(termFrequencies) + val model = idf.fit(termFrequencies) val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((m.toDouble + 1.0) / (x + 1.0)) }) - assert(idf.idf() ~== expected absTol 1e-12) - val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(model.idf ~== expected absTol 1e-12) + val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() assert(tfidf.size === 3) val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] assert(tfidf0.indices === Array(1, 3)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 5a9be923a8625..e217b93cebbdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -50,23 +50,17 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler() val standardizer3 = new StandardScaler(withMean = true, withStd = false) - withClue("Using a standardizer before fitting the model should throw exception.") { - intercept[IllegalStateException] { - data.map(standardizer1.transform) - } - } - - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(standardizer1.transform) - val data2 = data.map(standardizer2.transform) - val data3 = data.map(standardizer3.transform) + val data1 = data.map(model1.transform) + val data2 = data.map(model2.transform) + val data3 = data.map(model3.transform) - val data1RDD = standardizer1.transform(dataRDD) - val data2RDD = standardizer2.transform(dataRDD) - val data3RDD = standardizer3.transform(dataRDD) + val data1RDD = model1.transform(dataRDD) + val data2RDD = model2.transform(dataRDD) + val data3RDD = model3.transform(dataRDD) val summary = computeSummary(dataRDD) val summary1 = computeSummary(data1RDD) @@ -129,25 +123,25 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler() val standardizer3 = new StandardScaler(withMean = true, withStd = false) - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data2 = data.map(standardizer2.transform) + val data2 = data.map(model2.transform) withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(standardizer1.transform) + data.map(model1.transform) } } withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(standardizer3.transform) + data.map(model3.transform) } } - val data2RDD = standardizer2.transform(dataRDD) + val data2RDD = model2.transform(dataRDD) val summary2 = computeSummary(data2RDD) @@ -181,13 +175,13 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler(withMean = true, withStd = false) val standardizer3 = new StandardScaler(withMean = false, withStd = true) - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(standardizer1.transform) - val data2 = data.map(standardizer2.transform) - val data3 = data.map(standardizer3.transform) + val data1 = data.map(model1.transform) + val data2 = data.map(model2.transform) + val data3 = data.map(model3.transform) assert(data1.forall(_.toArray.forall(_ == 0.0)), "The variance is zero, so the transformed result should be 0.0") From c65c810c83e352f5b7198ce74f8f5561617a55bd Mon Sep 17 00:00:00 2001 From: Oleg Danilov Date: Thu, 7 Aug 2014 15:48:44 -0700 Subject: [PATCH 074/231] SPARK-2905 Fixed path sbin => bin Author: Oleg Danilov Closes #1835 from dosoft/SPARK-2905 and squashes the following commits: 4df423c [Oleg Danilov] SPARK-2905 Fixed path sbin => bin (cherry picked from commit 80ec5bad1311651fe56e1d5178090dc63753233b) Signed-off-by: Patrick Wendell --- bin/spark-sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-sql b/bin/spark-sql index 61ebd8ab6dec8..7813ccc361415 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -29,7 +29,7 @@ CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" FWDIR="$(cd `dirname $0`/..; pwd)" function usage { - echo "Usage: ./sbin/spark-sql [options] [cli option]" + echo "Usage: ./bin/spark-sql [options] [cli option]" pattern="usage" pattern+="\|Spark assembly has been built with Hive" pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" From 30369b80636032839992bf4bce1d1961062f0058 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 7 Aug 2014 16:24:22 -0700 Subject: [PATCH 075/231] SPARK-2899 Doc generation is back to working in new SBT Build. The reason for this bug was introduciton of OldDeps project. It had to be excluded to prevent unidocs from trying to put it on "docs compile" classpath. Author: Prashant Sharma Closes #1830 from ScrapCodes/doc-fix and squashes the following commits: e5d52e6 [Prashant Sharma] SPARK-2899 Doc generation is back to working in new SBT Build. (cherry picked from commit 32096c2aed9978cfb9a904b4f56bb61800d17e9e) Signed-off-by: Patrick Wendell --- project/SparkBuild.scala | 60 ++++++++++++++++++++++------------------ project/plugins.sbt | 2 +- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ed587783d5606..63a285b81a60c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,11 +30,11 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", + "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = @@ -44,8 +44,9 @@ object BuildCommons { val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) - val tools = "tools" - + val tools = ProjectRef(buildLocation, "tools") + // Root project. + val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation } @@ -126,26 +127,6 @@ object SparkBuild extends PomBuild { publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn ) - /** Following project only exists to pull previous artifacts of Spark for generating - Mima ignores. For more information see: SPARK 2071 */ - lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings) - - def versionArtifact(id: String): Option[sbt.ModuleID] = { - val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.0.0") - } - - def oldDepsSettings() = Defaults.defaultSettings ++ Seq( - name := "old-deps", - scalaVersion := "2.10.4", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", - "spark-core").map(versionArtifact(_).get intransitive()) - ) - def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { val existingSettings = projectsMap.getOrElse(projectRef.project, Seq[Setting[_]]()) projectsMap += (projectRef.project -> (existingSettings ++ settings)) @@ -184,7 +165,7 @@ object SparkBuild extends PomBuild { super.projectDefinitions(baseDirectory).map { x => if (projectsMap.exists(_._1 == x.id)) x.settings(projectsMap(x.id): _*) else x.settings(Seq[Setting[_]](): _*) - } ++ Seq[Project](oldDeps) + } ++ Seq[Project](OldDeps.project) } } @@ -193,6 +174,31 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + * Following project only exists to pull previous artifacts of Spark for generating + * Mima ignores. For more information see: SPARK 2071 + */ +object OldDeps { + + lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) + + def versionArtifact(id: String): Option[sbt.ModuleID] = { + val fullId = id + "_2.10" + Some("org.apache.spark" % fullId % "1.0.0") + } + + def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + name := "old-deps", + scalaVersion := "2.10.4", + retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", + "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-core").map(versionArtifact(_).get intransitive()) + ) +} + object Catalyst { lazy val settings = Seq( addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), @@ -285,9 +291,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, yarn, yarnAlpha), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { diff --git a/project/plugins.sbt b/project/plugins.sbt index 06d18e193076e..2a61f56c2ea60 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -23,6 +23,6 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.0") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") From 0f2274f8ed6131ad17326e3fff7f7e093863b72d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 7 Aug 2014 18:04:49 -0700 Subject: [PATCH 076/231] SPARK-2787: Make sort-based shuffle write files directly when there's no sorting/aggregation and # partitions is small As described in https://issues.apache.org/jira/browse/SPARK-2787, right now sort-based shuffle is more expensive than hash-based for map operations that do no partial aggregation or sorting, such as groupByKey. This is because it has to serialize each data item twice (once when spilling to intermediate files, and then again when merging these files object-by-object). This patch adds a code path to just write separate files directly if the # of output partitions is small, and concatenate them at the end to produce a sorted file. On the unit test side, I added some tests that force or don't force this bypass path to be used, and checked that our tests for other features (e.g. all the operations) cover both cases. Author: Matei Zaharia Closes #1799 from mateiz/SPARK-2787 and squashes the following commits: 88cf26a [Matei Zaharia] Fix rebase 10233af [Matei Zaharia] Review comments 398cb95 [Matei Zaharia] Fix looking up shuffle manager in conf ca3efd9 [Matei Zaharia] Add docs for shuffle manager properties, and allow short names for them d0ae3c5 [Matei Zaharia] Fix some comments 90d084f [Matei Zaharia] Add code path to bypass merge-sort in ExternalSorter, and tests 31e5d7c [Matei Zaharia] Move existing logic for writing partitioned files into ExternalSorter (cherry picked from commit 6906b69cf568015f20c7d7c77cbcba650e5431a9) Signed-off-by: Reynold Xin --- .../scala/org/apache/spark/SparkEnv.scala | 27 +- .../shuffle/hash/HashShuffleReader.scala | 2 +- .../shuffle/sort/SortShuffleWriter.scala | 80 ++---- .../util/collection/ExternalSorter.scala | 233 +++++++++++++++--- .../util/collection/ExternalSorterSuite.scala | 165 +++++++++++-- docs/configuration.md | 18 ++ 6 files changed, 407 insertions(+), 118 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 9d4edeb6d96cf..22d8d1cb1ddcf 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -156,11 +156,9 @@ object SparkEnv extends Logging { conf.set("spark.driver.port", boundPort.toString) } - // Create an instance of the class named by the given Java system property, or by - // defaultClassName if the property is not set, and return it as a T - def instantiateClass[T](propertyName: String, defaultClassName: String): T = { - val name = conf.get(propertyName, defaultClassName) - val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader) + // Create an instance of the class with the given name, possibly initializing it with our conf + def instantiateClass[T](className: String): T = { + val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { @@ -178,11 +176,17 @@ object SparkEnv extends Logging { } } - val serializer = instantiateClass[Serializer]( + // Create an instance of the class named by the given SparkConf property, or defaultClassName + // if the property is not set, possibly initializing it with our conf + def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = { + instantiateClass[T](conf.get(propertyName, defaultClassName)) + } + + val serializer = instantiateClassFromConf[Serializer]( "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val closureSerializer = instantiateClass[Serializer]( + val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") def registerOrLookup(name: String, newActor: => Actor): ActorRef = { @@ -246,8 +250,13 @@ object SparkEnv extends Logging { "." } - val shuffleManager = instantiateClass[ShuffleManager]( - "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + // Let the user specify short names for shuffle managers + val shortShuffleMgrNames = Map( + "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val shuffleMemoryManager = new ShuffleMemoryManager(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 7c9dc8e5f88ef..88a5f1e5ddf58 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -58,7 +58,7 @@ private[spark] class HashShuffleReader[K, C]( // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) - sorter.write(aggregatedIter) + sorter.insertAll(aggregatedIter) context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled sorter.iterator diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index e54e6383d2ccc..22f656fa371ea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -44,6 +44,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private var sorter: ExternalSorter[K, V, _] = null private var outputFile: File = null + private var indexFile: File = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -57,78 +58,36 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - // Get an iterator with the elements for each partition ID - val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = { - if (dep.mapSideCombine) { - if (!dep.aggregator.isDefined) { - throw new IllegalStateException("Aggregator is empty for map-side combine") - } - sorter = new ExternalSorter[K, V, C]( - dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.write(records) - sorter.partitionedIterator - } else { - // In this case we pass neither an aggregator nor an ordering to the sorter, because we - // don't care whether the keys get sorted in each partition; that will be done on the - // reduce side if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) - sorter.write(records) - sorter.partitionedIterator + if (dep.mapSideCombine) { + if (!dep.aggregator.isDefined) { + throw new IllegalStateException("Aggregator is empty for map-side combine") } + sorter = new ExternalSorter[K, V, C]( + dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + sorter.insertAll(records) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + sorter = new ExternalSorter[K, V, V]( + None, Some(dep.partitioner), None, dep.serializer) + sorter.insertAll(records) } // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later // serve different ranges of this file using an index file that we create at the end. val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) - outputFile = blockManager.diskBlockManager.getFile(blockId) - - // Track location of each range in the output file - val offsets = new Array[Long](numPartitions + 1) - val lengths = new Array[Long](numPartitions) - - for ((id, elements) <- partitions) { - if (elements.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize, - writeMetrics) - for (elem <- elements) { - writer.write(elem) - } - writer.commitAndClose() - val segment = writer.fileSegment() - offsets(id + 1) = segment.offset + segment.length - lengths(id) = segment.length - } else { - // The partition is empty; don't create a new writer to avoid writing headers, etc - offsets(id + 1) = offsets(id) - } - } - - context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled - // Write an index file with the offsets of each block, plus a final offset at the end for the - // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure - // out where each block begins and ends. + outputFile = blockManager.diskBlockManager.getFile(blockId) + indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index") - val diskBlockManager = blockManager.diskBlockManager - val indexFile = diskBlockManager.getFile(blockId.name + ".index") - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { - var i = 0 - while (i < numPartitions + 1) { - out.writeLong(offsets(i)) - i += 1 - } - } finally { - out.close() - } + val partitionLengths = sorter.writePartitionedFile(blockId, context) // Register our map output with the ShuffleBlockManager, which handles cleaning it over time blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions) mapStatus = new MapStatus(blockManager.blockManagerId, - lengths.map(MapOutputTracker.compressSize)) + partitionLengths.map(MapOutputTracker.compressSize)) } /** Close this writer, passing along whether the map completed */ @@ -145,6 +104,9 @@ private[spark] class SortShuffleWriter[K, V, C]( if (outputFile != null) { outputFile.delete() } + if (indexFile != null) { + indexFile.delete() + } return None } } finally { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index eb4849ebc6e52..b73d5e0cf1714 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -25,10 +25,10 @@ import scala.collection.mutable import com.google.common.io.ByteStreams -import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} +import org.apache.spark._ import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.BlockId import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.storage.{BlockObjectWriter, BlockId} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -67,6 +67,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics * for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. + * + * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is + * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to + * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can + * then concatenate these files to produce a single sorted file, without having to serialize and + * de-serialize each item twice (as is needed during the merge). This speeds up the map side of + * groupBy, sort, etc operations since they do no partial aggregation. */ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, @@ -124,6 +131,18 @@ private[spark] class ExternalSorter[K, V, C]( // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need + // local aggregation and sorting, write numPartitions files directly and just concatenate them + // at the end. This avoids doing serialization and deserialization twice to merge together the + // spilled files, which would happen with the normal code path. The downside is having multiple + // files open at a time and thus more memory allocated to buffers. + private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val bypassMergeSort = + (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty) + + // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled + private var partitionWriters: Array[BlockObjectWriter] = null + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -137,7 +156,14 @@ private[spark] class ExternalSorter[K, V, C]( } }) - // A comparator for (Int, K) elements that orders them by partition and then possibly by key + // A comparator for (Int, K) pairs that orders them by only their partition ID + private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + + // A comparator that orders (Int, K) pairs by partition ID and then possibly by key private val partitionKeyComparator: Comparator[(Int, K)] = { if (ordering.isDefined || aggregator.isDefined) { // Sort by partition ID then key comparator @@ -153,11 +179,7 @@ private[spark] class ExternalSorter[K, V, C]( } } else { // Just sort it by partition ID - new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } + partitionComparator } } @@ -171,7 +193,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsPerPartition: Array[Long]) private val spills = new ArrayBuffer[SpilledFile] - def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -242,6 +264,38 @@ private[spark] class ExternalSorter[K, V, C]( val threadId = Thread.currentThread().getId logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + + if (bypassMergeSort) { + spillToPartitionFiles(collection) + } else { + spillToMergeableFile(collection) + } + + if (usingMap) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } else { + buffer = new SizeTrackingPairBuffer[(Int, K), C] + } + + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0 + + _memoryBytesSpilled += memorySize + } + + /** + * Spill our in-memory collection to a sorted file that we can merge later (normal code path). + * We add this file into spilledFiles to find it later. + * + * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition. + * See spillToPartitionedFiles() for that code path. + * + * @param collection whichever collection we're using (map or buffer) + */ + private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + assert(!bypassMergeSort) + val (blockId, file) = diskBlockManager.createTempBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) @@ -304,18 +358,36 @@ private[spark] class ExternalSorter[K, V, C]( } } - if (usingMap) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] - } else { - buffer = new SizeTrackingPairBuffer[(Int, K), C] - } + spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) + } - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0 + /** + * Spill our in-memory collection to separate files, one for each partition. This is used when + * there's no aggregator and ordering and the number of partitions is small, because it allows + * writePartitionedFile to just concatenate files without deserializing data. + * + * @param collection whichever collection we're using (map or buffer) + */ + private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + assert(bypassMergeSort) + + // Create our file writers if we haven't done so yet + if (partitionWriters == null) { + curWriteMetrics = new ShuffleWriteMetrics() + partitionWriters = Array.fill(numPartitions) { + val (blockId, file) = diskBlockManager.createTempBlock() + blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() + } + } - spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) - _memoryBytesSpilled += memorySize + val it = collection.iterator // No need to sort stuff, just write each element out + while (it.hasNext) { + val elem = it.next() + val partitionId = elem._1._1 + val key = elem._1._2 + val value = elem._2 + partitionWriters(partitionId).write((key, value)) + } } /** @@ -479,7 +551,6 @@ private[spark] class ExternalSorter[K, V, C]( skipToNextPartition() - // Intermediate file and deserializer streams that read from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams var fileStream: FileInputStream = null @@ -619,23 +690,25 @@ private[spark] class ExternalSorter[K, V, C]( def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - if (spills.isEmpty) { + if (spills.isEmpty && partitionWriters == null) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { - // The user isn't requested sorted keys, so only sort by partition ID, not key - val partitionComparator = new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } + // The user hasn't requested sorted keys, so only sort by partition ID, not key groupByPartition(collection.destructiveSortedIterator(partitionComparator)) } else { // We do need to sort by both partition ID and key groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator)) } + } else if (bypassMergeSort) { + // Read data from each partition file and merge it together with the data in memory; + // note that there's no ordering or aggregator in this case -- we just partition objects + val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + collIter.map { case (partitionId, values) => + (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) + } } else { - // General case: merge spilled and in-memory data + // Merge spilled and in-memory data merge(spills, collection.destructiveSortedIterator(partitionKeyComparator)) } } @@ -645,9 +718,113 @@ private[spark] class ExternalSorter[K, V, C]( */ def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) + /** + * Write all the data added into this ExternalSorter into a file in the disk store, creating + * an .index file for it as well with the offsets of each partition. This is called by the + * SortShuffleWriter and can go through an efficient path of just concatenating binary files + * if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = { + val outputFile = blockManager.diskBlockManager.getFile(blockId) + + // Track location of each range in the output file + val offsets = new Array[Long](numPartitions + 1) + val lengths = new Array[Long](numPartitions) + + if (bypassMergeSort && partitionWriters != null) { + // We decided to write separate files for each partition, so just concatenate them. To keep + // this simple we spill out the current in-memory collection so that everything is in files. + spillToPartitionFiles(if (aggregator.isDefined) map else buffer) + partitionWriters.foreach(_.commitAndClose()) + var out: FileOutputStream = null + var in: FileInputStream = null + try { + out = new FileOutputStream(outputFile) + for (i <- 0 until numPartitions) { + val file = partitionWriters(i).fileSegment().file + in = new FileInputStream(file) + org.apache.spark.util.Utils.copyStream(in, out) + in.close() + in = null + lengths(i) = file.length() + offsets(i + 1) = offsets(i) + lengths(i) + } + } finally { + if (out != null) { + out.close() + } + if (in != null) { + in.close() + } + } + } else { + // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by + // partition and just write everything directly. + for ((id, elements) <- this.partitionedIterator) { + if (elements.hasNext) { + val writer = blockManager.getDiskWriter( + blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) + for (elem <- elements) { + writer.write(elem) + } + writer.commitAndClose() + val segment = writer.fileSegment() + offsets(id + 1) = segment.offset + segment.length + lengths(id) = segment.length + } else { + // The partition is empty; don't create a new writer to avoid writing headers, etc + offsets(id + 1) = offsets(id) + } + } + } + + context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += diskBytesSpilled + + // Write an index file with the offsets of each block, plus a final offset at the end for the + // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure + // out where each block begins and ends. + + val diskBlockManager = blockManager.diskBlockManager + val indexFile = diskBlockManager.getFile(blockId.name + ".index") + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + var i = 0 + while (i < numPartitions + 1) { + out.writeLong(offsets(i)) + i += 1 + } + } finally { + out.close() + } + + lengths + } + + /** + * Read a partition file back as an iterator (used in our iterator method) + */ + def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + if (writer.isOpen) { + writer.commitAndClose() + } + blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]] + } + def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() + if (partitionWriters != null) { + partitionWriters.foreach { w => + w.revertPartialWritesAndClose() + diskBlockManager.getFile(w.blockId).delete() + } + partitionWriters = null + } } def memoryBytesSpilled: Long = _memoryBytesSpilled diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 57dcb4ffabac1..706faed980f31 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite +import org.scalatest.{PrivateMethodTester, FunSuite} import org.apache.spark._ import org.apache.spark.SparkContext._ -class ExternalSorterSuite extends FunSuite with LocalSparkContext { +class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) // Make the Java serializer write a reset instruction (TC_RESET) after each object to test @@ -36,6 +36,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { conf } + private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { + val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) + assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort") + } + + private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { + val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) + assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort") + } + test("empty data stream") { val conf = new SparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -86,28 +96,28 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( Some(agg), Some(new HashPartitioner(7)), Some(ord), None) - sorter.write(elements.iterator) + sorter.insertAll(elements.iterator) assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( Some(agg), Some(new HashPartitioner(7)), None, None) - sorter2.write(elements.iterator) + sorter2.insertAll(elements.iterator) assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), Some(ord), None) - sorter3.write(elements.iterator) + sorter3.insertAll(elements.iterator) assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), None, None) - sorter4.write(elements.iterator) + sorter4.insertAll(elements.iterator) assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter4.stop() } @@ -118,13 +128,37 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + sorter.insertAll(elements) + assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === (0, Nil)) + assert(iter.next() === (1, List((1, 1)))) + assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) + assert(iter.next() === (3, Nil)) + assert(iter.next() === (4, Nil)) + assert(iter.next() === (5, List((5, 5)))) + assert(iter.next() === (6, Nil)) + sorter.stop() + } + + test("empty partitions with spilling, bypass merge-sort") { + val conf = createSparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), None, None) - sorter.write(elements) + assertBypassedMergeSort(sorter) + sorter.insertAll(elements) assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) assert(iter.next() === (0, Nil)) @@ -286,14 +320,43 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + + val sorter2 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter2) + sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) + sorter2.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in sorter, bypass merge-sort") { + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i, i))) + assertBypassedMergeSort(sorter) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) sorter.stop() assert(diskBlockManager.getAllBlocks().length === 0) val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter2.write((0 until 100000).iterator.map(i => (i, i))) + assertBypassedMergeSort(sorter2) + sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) sorter2.stop() @@ -307,9 +370,35 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + intercept[SparkException] { + sorter.insertAll((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") { + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + assertBypassedMergeSort(sorter) intercept[SparkException] { - sorter.write((0 until 100000).iterator.map(i => { + sorter.insertAll((0 until 100000).iterator.map(i => { if (i == 99990) { throw new SparkException("Intentional failure") } @@ -365,7 +454,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i / 4, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet) @@ -381,7 +470,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -397,7 +486,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -414,7 +503,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -431,7 +520,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100).iterator.map(i => (i, i))) + sorter.insertAll((0 until 100).iterator.map(i => (i, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq val expected = (0 until 3).map(p => { (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) @@ -448,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100000).iterator.map(i => (i, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq val expected = (0 until 3).map(p => { (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) @@ -495,7 +584,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++ collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) - sorter.write(toInsert) + sorter.insertAll(toInsert) // A map of collision pairs in both directions val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap @@ -524,7 +613,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1) - sorter.write(toInsert.iterator) + sorter.insertAll(toInsert.iterator) val it = sorter.iterator var count = 0 @@ -548,7 +637,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) val it = sorter.iterator while (it.hasNext) { @@ -572,7 +661,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( Some(agg), None, None, None) - sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), ("1", null.asInstanceOf[String]), (null.asInstanceOf[String], null.asInstanceOf[String]) @@ -584,4 +673,38 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { it.next() } } + + test("conditions for bypassing merge-sort") { + val conf = createSparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Numbers of partitions that are above and below the default bypassMergeThreshold + val FEW_PARTITIONS = 50 + val MANY_PARTITIONS = 10000 + + // Sorters with no ordering or aggregator: should bypass unless # of partitions is high + + val sorter1 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None) + assertBypassedMergeSort(sorter1) + + val sorter2 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None) + assertDidNotBypassMergeSort(sorter2) + + // Sorters with an ordering or aggregator: should not bypass even if they have few partitions + + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None) + assertDidNotBypassMergeSort(sorter3) + + val sorter4 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) + assertDidNotBypassMergeSort(sorter4) + } } diff --git a/docs/configuration.md b/docs/configuration.md index 5e3eb0f0871af..4d27c5a918fe0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -281,6 +281,24 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. + + spark.shuffle.manager + HASH + + Implementation to use for shuffling data. A hash-based shuffle manager is the default, but + starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more + memory-efficient in environments with small executors, such as YARN. To use that, change + this value to SORT. + + + + spark.shuffle.sort.bypassMergeThreshold + 200 + + (Advanced) In the sort-based shuffle manager, avoid merge-sorting data if there is no + map-side aggregation and there are at most this many reduce partitions. + + #### Spark UI From aab7735d3162a4286cfbdb078c781d0326e074ad Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 7 Aug 2014 18:09:03 -0700 Subject: [PATCH 077/231] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched Author: Sandy Ryza Closes #1507 from sryza/sandy-spark-2565 and squashes the following commits: 74dad41 [Sandy Ryza] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched (cherry picked from commit 4c51098f320f164eb66f92ff0f26b0b595a58f38) Signed-off-by: Patrick Wendell --- .../org/apache/spark/executor/Executor.scala | 1 + .../apache/spark/executor/TaskMetrics.scala | 55 ++++++++++++++----- .../hash/BlockStoreShuffleFetcher.scala | 13 ++--- .../shuffle/hash/HashShuffleReader.scala | 4 +- .../spark/storage/BlockFetcherIterator.scala | 40 +++++--------- .../apache/spark/storage/BlockManager.scala | 11 ++-- .../org/apache/spark/util/JsonProtocol.scala | 5 +- .../storage/BlockFetcherIteratorSuite.scala | 13 +++-- .../ui/jobs/JobProgressListenerSuite.scala | 4 +- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- 10 files changed, 84 insertions(+), 64 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c2b9c660ddaec..eac1f2326a29d 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -374,6 +374,7 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values()) { if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + metrics.updateShuffleReadMetrics tasksMetrics += ((taskRunner.taskId, metrics)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 11a6e10243211..99a88c13456df 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,8 @@ package org.apache.spark.executor +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.{BlockId, BlockStatus} @@ -81,12 +83,27 @@ class TaskMetrics extends Serializable { var inputMetrics: Option[InputMetrics] = None /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here + * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. + * This includes read metrics aggregated over all the task's shuffle dependencies. */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None def shuffleReadMetrics = _shuffleReadMetrics + /** + * This should only be used when recreating TaskMetrics, not when updating read metrics in + * executors. + */ + private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) { + _shuffleReadMetrics = shuffleReadMetrics + } + + /** + * ShuffleReadMetrics per dependency for collecting independently while task is in progress. + */ + @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] = + new ArrayBuffer[ShuffleReadMetrics]() + /** * If this task writes to shuffle output, metrics on the written shuffle data will be collected * here @@ -98,19 +115,31 @@ class TaskMetrics extends Serializable { */ var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None - /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */ - def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized { - _shuffleReadMetrics match { - case Some(existingMetrics) => - existingMetrics.shuffleFinishTime = math.max( - existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime) - existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime - existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched - existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched - existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead - case None => - _shuffleReadMetrics = Some(newMetrics) + /** + * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization + * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each + * dependency, and merge these metrics before reporting them to the driver. This method returns + * a ShuffleReadMetrics for a dependency and registers it for merging later. + */ + private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized { + val readMetrics = new ShuffleReadMetrics() + depsShuffleReadMetrics += readMetrics + readMetrics + } + + /** + * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. + */ + private[spark] def updateShuffleReadMetrics() = synchronized { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.fetchWaitTime += depMetrics.fetchWaitTime + merged.localBlocksFetched += depMetrics.localBlocksFetched + merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched + merged.remoteBytesRead += depMetrics.remoteBytesRead + merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime) } + _shuffleReadMetrics = Some(merged) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 99788828981c7..12b475658e29d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) + serializer: Serializer, + shuffleMetrics: ShuffleReadMetrics) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { - val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleFinishTime = System.currentTimeMillis - shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime - shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead - shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics) + context.taskMetrics.updateShuffleReadMetrics() }) new InterruptibleIterator[T](context, completionIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 88a5f1e5ddf58..7bed97a63f0f6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, + readMetrics) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 938af6f5b923a..5f44f5f3197fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -27,6 +27,7 @@ import scala.util.{Failure, Success} import io.netty.buffer.ByteBuf import org.apache.spark.{Logging, SparkException} +import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId import org.apache.spark.network.netty.ShuffleCopier @@ -47,10 +48,6 @@ import org.apache.spark.util.Utils private[storage] trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { def initialize() - def numLocalBlocks: Int - def numRemoteBlocks: Int - def fetchWaitTime: Long - def remoteBytesRead: Long } @@ -72,14 +69,12 @@ object BlockFetcherIterator { class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) + serializer: Serializer, + readMetrics: ShuffleReadMetrics) extends BlockFetcherIterator { import blockManager._ - private var _remoteBytesRead = 0L - private var _fetchWaitTime = 0L - if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } @@ -89,13 +84,9 @@ object BlockFetcherIterator { protected var startTime = System.currentTimeMillis - // This represents the number of local blocks, also counting zero-sized blocks - private var numLocal = 0 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - // This represents the number of remote blocks, also counting zero-sized blocks - private var numRemote = 0 // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks protected val remoteBlocksToFetch = new HashSet[BlockId]() @@ -132,7 +123,10 @@ object BlockFetcherIterator { val networkSize = blockMessage.getData.limit() results.put(new FetchResult(blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize + // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can + // be incrementing bytes read at the same time (SPARK-2625). + readMetrics.remoteBytesRead += networkSize + readMetrics.remoteBlocksFetched += 1 logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } @@ -155,14 +149,14 @@ object BlockFetcherIterator { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] + var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size if (address == blockManagerId) { - numLocal = blockInfos.size // Filter out zero-sized blocks localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) _numBlocksToFetch += localBlocksToFetch.size } else { - numRemote += blockInfos.size val iterator = blockInfos.iterator var curRequestSize = 0L var curBlocks = new ArrayBuffer[(BlockId, Long)] @@ -192,7 +186,7 @@ object BlockFetcherIterator { } } logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - (numLocal + numRemote) + " blocks") + totalBlocks + " blocks") remoteRequests } @@ -205,6 +199,7 @@ object BlockFetcherIterator { // getLocalFromDisk never return None but throws BlockException val iter = getLocalFromDisk(id, serializer).get // Pass 0 as size since it's not in flight + readMetrics.localBlocksFetched += 1 results.put(new FetchResult(id, 0, () => iter)) logDebug("Got local block " + id) } catch { @@ -238,12 +233,6 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def numLocalBlocks: Int = numLocal - override def numRemoteBlocks: Int = numRemote - override def fetchWaitTime: Long = _fetchWaitTime - override def remoteBytesRead: Long = _remoteBytesRead - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue // as they arrive. @volatile protected var resultsGotten = 0 @@ -255,7 +244,7 @@ object BlockFetcherIterator { val startFetchWait = System.currentTimeMillis() val result = results.take() val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) + readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (! result.failed) bytesInFlight -= result.size while (!fetchRequests.isEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { @@ -269,8 +258,9 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { + serializer: Serializer, + readMetrics: ShuffleReadMetrics) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { import blockManager._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8d21b02b747ff..e8bbd298c631a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics} +import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -539,12 +539,15 @@ private[spark] class BlockManager( */ def getMultiple( blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer): BlockFetcherIterator = { + serializer: Serializer, + readMetrics: ShuffleReadMetrics): BlockFetcherIterator = { val iter = if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } iter.initialize() iter diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b112b359368cd..6f8eb1ee12634 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -560,9 +560,8 @@ private[spark] object JsonProtocol { metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long] metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long] metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long] - Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics => - metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics)) - } + metrics.setShuffleReadMetrics( + Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.inputMetrics = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 1538995a6b404..bcbfe8baf36ad 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock import org.apache.spark.storage.BlockFetcherIterator._ import org.apache.spark.network.{ConnectionManager, Message} +import org.apache.spark.executor.ShuffleReadMetrics class BlockFetcherIteratorSuite extends FunSuite with Matchers { @@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -165,7 +166,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { ) val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + blocksByAddress, null, new ShuffleReadMetrics()) iterator.initialize() iterator.foreach{ @@ -219,7 +220,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { ) val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + blocksByAddress, null, new ShuffleReadMetrics()) iterator.initialize() iterator.foreach{ case (_, r) => { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index cb8252515238e..f5ba31c309277 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) @@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) shuffleReadMetrics.remoteBytesRead = base + 1 shuffleReadMetrics.remoteBlocksFetched = base + 2 diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 2002a817d9168..97ffb07662482 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite { sr.localBlocksFetched = e sr.fetchWaitTime = a + d sr.remoteBlocksFetched = f - t.updateShuffleReadMetrics(sr) + t.setShuffleReadMetrics(Some(sr)) } sw.shuffleBytesWritten = a + b + c sw.shuffleWriteTime = b + c + d From a54b5d955df151562721fc04b438337d15ab1dec Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 7 Aug 2014 18:53:15 -0700 Subject: [PATCH 078/231] [SPARK-2904] Remove non-used local variable in SparkSubmitArguments Author: Kousuke Saruta Closes #1834 from sarutak/SPARK-2904 and squashes the following commits: 38e7d45 [Kousuke Saruta] Removed non-used variable in SparkSubmitArguments (cherry picked from commit 9de6a42bb34ea8963225ce90f1a45adcfee38b58) Signed-off-by: Patrick Wendell --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 087dd4d633db0..c21f1529a1837 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -219,7 +219,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - var inSparkOpts = true val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r // Delineates parsing of Spark options from parsing of user options. From 3eb5dd043427de8c050687231011863b22feecdb Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:01:51 -0700 Subject: [PATCH 079/231] [SPARK-2888] [SQL] Fix addColumnMetadataToConf in HiveTableScan JIRA: https://issues.apache.org/jira/browse/SPARK-2888 Author: Yin Huai Closes #1817 from yhuai/fixAddColumnMetadataToConf and squashes the following commits: fba728c [Yin Huai] Fix addColumnMetadataToConf. (cherry picked from commit 9016af3f2729101027e33593e094332f05f48d92) Signed-off-by: Michael Armbrust --- .../sql/hive/execution/HiveTableScan.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 8920e2a76a27f..577ca928b43b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -72,17 +72,12 @@ case class HiveTableScan( } private def addColumnMetadataToConf(hiveConf: HiveConf) { - // Specifies IDs and internal names of columns to be scanned. - val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer) - val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") - - if (attributes.size == relation.output.size) { - // SQLContext#pruneFilterProject guarantees no duplicated value in `attributes` - ColumnProjectionUtils.setFullyReadColumns(hiveConf) - } else { - ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) - } + // Specifies needed column IDs for those non-partitioning columns. + val neededColumnIDs = + attributes.map(a => + relation.attributes.indexWhere(_.name == a.name): Integer).filter(index => index >= 0) + ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) // Specifies types and object inspectors of columns to be scanned. @@ -99,7 +94,7 @@ case class HiveTableScan( .mkString(",") hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) - hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) + hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(",")) } addColumnMetadataToConf(context.hiveconf) From 544a909ccd99d9a3c6ac2f21bd1802c18f7b950a Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:10:11 -0700 Subject: [PATCH 080/231] [SPARK-2908] [SQL] JsonRDD.nullTypeToStringType does not convert all NullType to StringType JIRA: https://issues.apache.org/jira/browse/SPARK-2908 Author: Yin Huai Closes #1840 from yhuai/SPARK-2908 and squashes the following commits: 86e833e [Yin Huai] Update test. cb11759 [Yin Huai] nullTypeToStringType should check columns with the type of array of structs. (cherry picked from commit 0489cee6b24ca34f1adab03a75d157e04a9e06b7) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/json/JsonRDD.scala | 4 +++- .../scala/org/apache/spark/sql/json/JsonSuite.scala | 11 ++++++++--- .../org/apache/spark/sql/json/TestJsonData.scala | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index a3d2a1c7a51f8..1c0b03c684f10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -109,7 +109,9 @@ private[sql] object JsonRDD extends Logging { val newType = dataType match { case NullType => StringType case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case struct: StructType => nullTypeToStringType(struct) + case ArrayType(struct: StructType, containsNull) => + ArrayType(nullTypeToStringType(struct), containsNull) + case struct: StructType =>nullTypeToStringType(struct) case other: DataType => other } StructField(fieldName, newType, nullable) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 75c0589eb208e..58b1e23891a3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -213,7 +213,8 @@ class JsonSuite extends QueryTest { StructField("arrayOfStruct", ArrayType( StructType( StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: Nil)), true) :: + StructField("field2", StringType, true) :: + StructField("field3", StringType, true) :: Nil)), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: StructField("field2", DecimalType, true) :: Nil), true) :: @@ -263,8 +264,12 @@ class JsonSuite extends QueryTest { // Access elements of an array of structs. checkAnswer( - sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2] from jsonTable"), - (true :: "str1" :: Nil, false :: null :: Nil, null) :: Nil + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + "from jsonTable"), + (true :: "str1" :: null :: Nil, + false :: null :: null :: Nil, + null :: null :: null :: Nil, + null) :: Nil ) // Access a struct and fields inside of it. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index d0180f3754f22..a88310b5f1b46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -43,7 +43,7 @@ object TestJsonData { "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], "arrayOfBoolean":[true, false, true], "arrayOfNull":[null, null, null, null], - "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}], + "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) From 8b0188b43e63f3d7795684aa36b4bd6e9efb0129 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:15:16 -0700 Subject: [PATCH 081/231] [SPARK-2877] [SQL] MetastoreRelation should use SparkClassLoader when creating the tableDesc JIRA: https://issues.apache.org/jira/browse/SPARK-2877 Author: Yin Huai Closes #1806 from yhuai/SPARK-2877 and squashes the following commits: 4142bcb [Yin Huai] Use Spark's classloader. (cherry picked from commit c874723fa844b49f057bb2434a12228b2f717e99) Signed-off-by: Michael Armbrust --- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 301cf51c00e2b..82e9c1a248626 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -288,7 +287,10 @@ private[hive] case class MetastoreRelation ) val tableDesc = new TableDesc( - Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], + Class.forName( + hiveQlTable.getSerializationLib, + true, + Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to From daa090f80942dc1476d86685fc1a3fb3392cf6ed Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:23:58 -0700 Subject: [PATCH 082/231] [SPARK-2919] [SQL] Basic support for analyze command in HiveQl The command we will support is ``` ANALYZE TABLE tablename COMPUTE STATISTICS noscan ``` Other cases shown in https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables will still be treated as Hive native commands. JIRA: https://issues.apache.org/jira/browse/SPARK-2919 Author: Yin Huai Closes #1848 from yhuai/sqlAnalyze and squashes the following commits: 0b79d36 [Yin Huai] Typo and format. c59d94b [Yin Huai] Support "ANALYZE TABLE tableName COMPUTE STATISTICS noscan". (cherry picked from commit 45d8f4deab50ae069ecde2201bd486d464a4501e) Signed-off-by: Michael Armbrust --- .../org/apache/spark/sql/hive/HiveQl.scala | 21 +++++++-- .../spark/sql/hive/HiveStrategies.scala | 2 + .../{DropTable.scala => commands.scala} | 26 +++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 45 ++++++++++++++++++- 4 files changed, 89 insertions(+), 5 deletions(-) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/{DropTable.scala => commands.scala} (72%) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index bc2fefafd58c8..05b2f5f6cd3f7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -46,6 +46,8 @@ private[hive] case class AddFile(filePath: String) extends Command private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command +private[hive] case class AnalyzeTable(tableName: String) extends Command + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( @@ -74,7 +76,6 @@ private[hive] object HiveQl { "TOK_CREATEFUNCTION", "TOK_DROPFUNCTION", - "TOK_ANALYZE", "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", @@ -92,7 +93,6 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_ANALYZE", "TOK_CREATEDATABASE", "TOK_CREATEFUNCTION", "TOK_CREATEINDEX", @@ -239,7 +239,6 @@ private[hive] object HiveQl { ShellCommand(sql.drop(1)) } else { val tree = getAst(sql) - if (nativeCommands contains tree.getText) { NativeCommand(sql) } else { @@ -387,6 +386,22 @@ private[hive] object HiveQl { ifExists) => val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") DropTable(tableName, ifExists.nonEmpty) + // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" + case Token("TOK_ANALYZE", + Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: + isNoscan) => + // Reference: + // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables + if (partitionSpec.nonEmpty) { + // Analyze partitions will be treated as a Hive native command. + NativePlaceholder + } else if (isNoscan.isEmpty) { + // If users do not specify "noscan", it will be treated as a Hive native command. + NativePlaceholder + } else { + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + AnalyzeTable(tableName) + } // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 2175c5f3835a6..85d2496a34cfb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -83,6 +83,8 @@ private[hive] trait HiveStrategies { case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil + case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala similarity index 72% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 9cd0c86c6c796..2985169da033c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -23,6 +23,32 @@ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{Command, LeafNode} import org.apache.spark.sql.hive.HiveContext +/** + * :: DeveloperApi :: + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. + */ +@DeveloperApi +case class AnalyzeTable(tableName: String) extends LeafNode with Command { + + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + def output = Seq.empty + + override protected[sql] lazy val sideEffectResult = { + hiveContext.analyze(tableName) + Seq.empty[Any] + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } +} + /** * :: DeveloperApi :: * Drops a table from the metastore and removes it if it is cached. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index bf5931bbf97ee..7c82964b5ecdc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,13 +19,54 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag + import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.catalyst.plans.logical.NativeCommand import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest { + test("parse analyze commands") { + def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { + val parsed = HiveQl.parseSql(analyzeCommand) + val operators = parsed.collect { + case a: AnalyzeTable => a + case o => o + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail( + s"""$analyzeCommand expected command: $c, but got ${operators(0)} + |parsed command: + |$parsed + """.stripMargin) + } + } + + assertAnalyzeCommand( + "ANALYZE TABLE Table1 COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan", + classOf[NativeCommand]) + + assertAnalyzeCommand( + "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn", + classOf[AnalyzeTable]) + } + test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = catalog.lookupRelation(None, tableName).statistics.sizeInBytes @@ -37,7 +78,7 @@ class StatisticsSuite extends QueryTest { assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) - analyze("analyzeTable") + sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) @@ -66,7 +107,7 @@ class StatisticsSuite extends QueryTest { assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) - analyze("analyzeTable_part") + sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) From e264503832a331c5b1344e8343ca9834db70bb11 Mon Sep 17 00:00:00 2001 From: chutium Date: Fri, 8 Aug 2014 13:31:08 -0700 Subject: [PATCH 083/231] [SPARK-2700] [SQL] Hidden files (such as .impala_insert_staging) should be filtered out by sqlContext.parquetFile Author: chutium Closes #1691 from chutium/SPARK-2700 and squashes the following commits: b76ae8c [chutium] [SPARK-2700] [SQL] fixed styling issue d75a8bd [chutium] [SPARK-2700] [SQL] Hidden files (such as .impala_insert_staging) should be filtered out by sqlContext.parquetFile (cherry picked from commit b7c89a7f0ca73153dce36e0f01b81a3947ee1189) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index aaef1a1d474fe..2867dc0a8b1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -373,8 +373,9 @@ private[parquet] object ParquetTypesConverter extends Logging { } ParquetRelation.enableLogForwarding() - val children = fs.listStatus(path).filterNot { - _.getPath.getName == FileOutputCommitter.SUCCEEDED_FILE_NAME + val children = fs.listStatus(path).filterNot { status => + val name = status.getPath.getName + name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME } // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row From 8fba6de31c4be0b1d28a4fceb8164d52cd0ee712 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 8 Aug 2014 15:07:31 -0700 Subject: [PATCH 084/231] [SPARK-1997][MLLIB] update breeze to 0.9 0.9 dependences (this version doesn't depend on scalalogging and I excluded commons-math3 from its transitive dependencies): ~~~ +-org.scalanlp:breeze_2.10:0.9 [S] +-com.github.fommil.netlib:core:1.1.2 +-com.github.rwl:jtransforms:2.4.0 +-net.sf.opencsv:opencsv:2.3 +-net.sourceforge.f2j:arpack_combined_all:0.1 +-org.scalanlp:breeze-macros_2.10:0.3.1 [S] | +-org.scalamacros:quasiquotes_2.10:2.0.0 [S] | +-org.slf4j:slf4j-api:1.7.5 +-org.spire-math:spire_2.10:0.7.4 [S] +-org.scalamacros:quasiquotes_2.10:2.0.0 [S] | +-org.spire-math:spire-macros_2.10:0.7.4 [S] +-org.scalamacros:quasiquotes_2.10:2.0.0 [S] ~~~ Closes #1749 CC: witgo avati Author: Xiangrui Meng Closes #1857 from mengxr/breeze-0.9 and squashes the following commits: 7fc16b6 [Xiangrui Meng] don't know why but exclude a private method for mima dcc502e [Xiangrui Meng] update breeze to 0.9 (cherry picked from commit 74d6f62264babfc6045c21545552f0a2e6958155) Signed-off-by: Xiangrui Meng --- mllib/pom.xml | 2 +- .../org/apache/spark/mllib/linalg/distributed/RowMatrix.scala | 4 ++-- .../spark/mllib/linalg/distributed/RowMatrixSuite.scala | 2 +- project/MimaExcludes.scala | 4 ++++ 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index d78fed794470c..d5c2e5ab54caa 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,7 +57,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.7 + 0.9 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 45486b2c7d82d..e76bc9fefff01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -222,7 +222,7 @@ class RowMatrix( EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) case SVDMode.LocalLAPACK => val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] - val (uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) + val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) case SVDMode.DistARPACK => require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.") @@ -338,7 +338,7 @@ class RowMatrix( val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]] - val (u: BDM[Double], _, _) = brzSvd(Cov) + val brzSvd.SVD(u: BDM[Double], _, _) = brzSvd(Cov) if (k == n) { Matrices.dense(n, k, u.data) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 325b817980f68..1d3a3221365cc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -99,7 +99,7 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { for (mat <- Seq(denseMat, sparseMat)) { for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { val localMat = mat.toBreeze() - val (localU, localSigma, localVt) = brzSvd(localMat) + val brzSvd.SVD(localU, localSigma, localVt) = brzSvd(localMat) val localV: BDM[Double] = localVt.t.toDenseMatrix for (k <- 1 to n) { val skip = (mode == "local-eigs" || mode == "dist-eigs") && k == n diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 537ca0dcf267d..b4653c72c10b5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -110,6 +110,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") + ) ++ + Seq ( // package-private classes removed in MLlib + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") ) case v if v.startsWith("1.0") => Seq( From dd11e4e4253897a18d3bb50f8293a580cbe578b3 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 8 Aug 2014 16:57:26 -0700 Subject: [PATCH 085/231] [SPARK-2897][SPARK-2920]TorrentBroadcast does use the serializer class specified in the spark option "spark.serializer" Author: GuoQiang Li Closes #1836 from witgo/SPARK-2897 and squashes the following commits: 23cdc5b [GuoQiang Li] review commit ada4fba [GuoQiang Li] TorrentBroadcast does not support broadcast compression fb91792 [GuoQiang Li] org.apache.spark.broadcast.TorrentBroadcast does use the serializer class specified in the spark option "spark.serializer" (cherry picked from commit ec79063fad44751a6689f5e58d47886babeaecff) Signed-off-by: Reynold Xin --- .../spark/broadcast/TorrentBroadcast.scala | 31 +++++++++++++++---- .../spark/broadcast/BroadcastSuite.scala | 10 ++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 86731b684f441..fe73456ef8fad 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,14 +17,15 @@ package org.apache.spark.broadcast -import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} +import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream, + ObjectInputStream, ObjectOutputStream, OutputStream} import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.Utils /** * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like @@ -214,11 +215,15 @@ private[broadcast] object TorrentBroadcast extends Logging { private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null + private var compress: Boolean = false + private var compressionCodec: CompressionCodec = null def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { if (!initialized) { + compress = conf.getBoolean("spark.broadcast.compress", true) + compressionCodec = CompressionCodec.createCodec(conf) initialized = true } } @@ -228,8 +233,13 @@ private[broadcast] object TorrentBroadcast extends Logging { initialized = false } - def blockifyObject[T](obj: T): TorrentInfo = { - val byteArray = Utils.serialize[T](obj) + def blockifyObject[T: ClassTag](obj: T): TorrentInfo = { + val bos = new ByteArrayOutputStream() + val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos + val ser = SparkEnv.get.serializer.newInstance() + val serOut = ser.serializeStream(out) + serOut.writeObject[T](obj).close() + val byteArray = bos.toByteArray val bais = new ByteArrayInputStream(byteArray) var blockNum = byteArray.length / BLOCK_SIZE @@ -255,7 +265,7 @@ private[broadcast] object TorrentBroadcast extends Logging { info } - def unBlockifyObject[T]( + def unBlockifyObject[T: ClassTag]( arrayOfBlocks: Array[TorrentBlock], totalBytes: Int, totalBlocks: Int): T = { @@ -264,7 +274,16 @@ private[broadcast] object TorrentBroadcast extends Logging { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) } - Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) + + val in: InputStream = { + val arrIn = new ByteArrayInputStream(retByteArray) + if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn + } + val ser = SparkEnv.get.serializer.newInstance() + val serIn = ser.deserializeStream(in) + val obj = serIn.readObject[T]() + serIn.close() + obj } /** diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 7c3d0208b195a..17c64455b2429 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -44,7 +44,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Accessing HttpBroadcast variables in a local cluster") { val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val conf = httpConf.clone + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.broadcast.compress", "true") + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -69,7 +72,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Accessing TorrentBroadcast variables in a local cluster") { val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val conf = torrentConf.clone + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.broadcast.compress", "true") + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) From 3311da2f9efc5ff2c7d01273ac08f719b067d11d Mon Sep 17 00:00:00 2001 From: li-zhihui Date: Fri, 8 Aug 2014 22:52:56 -0700 Subject: [PATCH 086/231] [SPARK-2635] Fix race condition at SchedulerBackend.isReady in standalone mode In SPARK-1946(PR #900), configuration spark.scheduler.minRegisteredExecutorsRatio was introduced. However, in standalone mode, there is a race condition where isReady() can return true because totalExpectedExecutors has not been correctly set. Because expected executors is uncertain in standalone mode, the PR try to use CPU cores(--total-executor-cores) as expected resources to judge whether SchedulerBackend is ready. Author: li-zhihui Author: Li Zhihui Closes #1525 from li-zhihui/fixre4s and squashes the following commits: e9a630b [Li Zhihui] Rename variable totalExecutors and clean codes abf4860 [Li Zhihui] Push down variable totalExpectedResources to children classes ca54bd9 [li-zhihui] Format log with String interpolation 88c7dc6 [li-zhihui] Few codes and docs refactor 41cf47e [li-zhihui] Fix race condition at SchedulerBackend.isReady in standalone mode (cherry picked from commit 28dbae85aaf6842e22cd7465cb11cb34d58fc56d) Signed-off-by: Patrick Wendell --- .../CoarseGrainedSchedulerBackend.scala | 30 +++++++++---------- .../cluster/SparkDeploySchedulerBackend.scala | 6 +++- docs/configuration.md | 13 ++++---- .../cluster/YarnClientSchedulerBackend.scala | 9 ++++-- .../cluster/YarnClusterSchedulerBackend.scala | 17 +++++++---- 5 files changed, 43 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9f085eef46720..33500d967ebb1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -47,19 +47,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - var totalExpectedExecutors = new AtomicInteger(0) + var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - // Submit tasks only after (registered executors / total expected executors) + // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. - var minRegisteredRatio = conf.getDouble("spark.scheduler.minRegisteredExecutorsRatio", 0) - if (minRegisteredRatio > 1) minRegisteredRatio = 1 - // Whatever minRegisteredExecutorsRatio is arrived, submit tasks after the time(milliseconds). + var minRegisteredRatio = + math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) + // Submit tasks after maxRegisteredWaitingTime milliseconds + // if minRegisteredRatio has not yet been reached val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredExecutorsWaitingTime", 30000) + conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() - var ready = if (minRegisteredRatio <= 0) true else false class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { private val executorActor = new HashMap[String, ActorRef] @@ -94,12 +94,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A executorAddress(executorId) = sender.path.address addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) - if (executorActor.size >= totalExpectedExecutors.get() * minRegisteredRatio && !ready) { - ready = true - logInfo("SchedulerBackend is ready for scheduling beginning, registered executors: " + - executorActor.size + ", total expected executors: " + totalExpectedExecutors.get() + - ", minRegisteredExecutorsRatio: " + minRegisteredRatio) - } + totalRegisteredExecutors.addAndGet(1) makeOffers() } @@ -268,14 +263,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } } + def sufficientResourcesRegistered(): Boolean = true + override def isReady(): Boolean = { - if (ready) { + if (sufficientResourcesRegistered) { + logInfo("SchedulerBackend is ready for scheduling beginning after " + + s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { - ready = true logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - "maxRegisteredExecutorsWaitingTime: " + maxRegisteredWaitingTime) + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") return true } false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a28446f6c8a6b..589dba2e40d20 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -36,6 +36,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + val totalExpectedCores = maxCores.getOrElse(0) override def start() { super.start() @@ -97,7 +98,6 @@ private[spark] class SparkDeploySchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - totalExpectedExecutors.addAndGet(1) logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } @@ -110,4 +110,8 @@ private[spark] class SparkDeploySchedulerBackend( logInfo("Executor %s removed: %s".format(fullId, message)) removeExecutor(fullId.split("/")(1), reason.toString) } + + override def sufficientResourcesRegistered(): Boolean = { + totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio + } } diff --git a/docs/configuration.md b/docs/configuration.md index 4d27c5a918fe0..617a72a021f6e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -825,21 +825,22 @@ Apart from these, the following properties are also available, and may be useful - spark.scheduler.minRegisteredExecutorsRatio + spark.scheduler.minRegisteredResourcesRatio 0 - The minimum ratio of registered executors (registered executors / total expected executors) + The minimum ratio of registered resources (registered resources / total expected resources) + (resources are executors in yarn mode, CPU cores in standalone mode) to wait for before scheduling begins. Specified as a double between 0 and 1. - Regardless of whether the minimum ratio of executors has been reached, + Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime 30000 - Maximum amount of time to wait for executors to register before scheduling begins + Maximum amount of time to wait for resources to register before scheduling begins (in milliseconds). diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f8fb96b312f23..833e249f9f612 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -30,15 +30,15 @@ private[spark] class YarnClientSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } var client: Client = null var appId: ApplicationId = null var checkerThread: Thread = null var stopping: Boolean = false + var totalExpectedExecutors = 0 private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { @@ -84,7 +84,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf) val args = new ClientArguments(argsArrayBuf.toArray, conf) - totalExpectedExecutors.set(args.numExecutors) + totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.runApp() waitForApp() @@ -150,4 +150,7 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 0ad1794d19538..55665220a6f96 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -27,19 +27,24 @@ private[spark] class YarnClusterSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + var totalExpectedExecutors = 0 + + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } override def start() { super.start() - var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + totalExpectedExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { - numExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")).getOrElse(numExecutors) + totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) + .getOrElse(totalExpectedExecutors) } // System property can override environment variable. - numExecutors = sc.getConf.getInt("spark.executor.instances", numExecutors) - totalExpectedExecutors.set(numExecutors) + totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } } From 71fcd2ea1e2561c41b40fdd2f53b334b198368cf Mon Sep 17 00:00:00 2001 From: Chandan Kumar Date: Sat, 9 Aug 2014 00:45:54 -0700 Subject: [PATCH 087/231] [SPARK-2861] Fix Doc comment of histogram method Tested and ready to merge. Author: Chandan Kumar Closes #1786 from nrchandan/spark-2861 and squashes the following commits: cb0bc1e [Chandan Kumar] [SPARK-2861] Fix a typo in the histogram doc comment 6a2a71b [Chandan Kumar] SPARK-2861. Fix Doc comment of histogram method (cherry picked from commit b431e6747f410aaf9624585920adc1f303159861) Signed-off-by: Patrick Wendell --- .../scala/org/apache/spark/rdd/DoubleRDDFunctions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 9ca971c8a4c27..f233544d128f5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -119,11 +119,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** * Compute a histogram using the provided buckets. The buckets are all open - * to the left except for the last which is closed + * to the right except for the last which is closed * e.g. for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets From 4a7f3ef882700ea7ec005dba77480c559565943f Mon Sep 17 00:00:00 2001 From: Chris Cope Date: Sat, 9 Aug 2014 20:58:56 -0700 Subject: [PATCH 088/231] [SPARK-1766] sorted functions to meet pedantic requirements Pedantry is underrated Author: Chris Cope Closes #1859 from copester/master and squashes the following commits: 0fb4499 [Chris Cope] [SPARK-1766] sorted functions to meet pedantic requirements --- .../apache/spark/rdd/PairRDDFunctions.scala | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 93af50c0a9cd1..5dd6472b0776c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -237,6 +237,25 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey[V]((v: V) => v, func, func, partitioner) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + */ + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + reduceByKey(new HashPartitioner(numPartitions), func) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. + */ + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + reduceByKey(defaultPartitioner(self), func) + } + /** * Merge the values for each key using an associative reduce function, but return the results * immediately to the master as a Map. This will also perform the merging locally on each mapper @@ -374,15 +393,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) countApproxDistinctByKey(relativeSD, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numPartitions), func) - } - /** * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. @@ -482,16 +492,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { - reduceByKey(defaultPartitioner(self), func) - } - /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. From ba223b8ecf00df4acf588f3a91fd9860f5e1b135 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 9 Aug 2014 21:10:43 -0700 Subject: [PATCH 089/231] [SPARK-2894] spark-shell doesn't accept flags As sryza reported, spark-shell doesn't accept any flags. The root cause is wrong usage of spark-submit in spark-shell and it come to the surface by #1801 Author: Kousuke Saruta Author: Cheng Lian Closes #1715, Closes #1864, and Closes #1861 Closes #1825 from sarutak/SPARK-2894 and squashes the following commits: 47f3510 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2894 2c899ed [Kousuke Saruta] Removed useless code from java_gateway.py 98287ed [Kousuke Saruta] Removed useless code from java_gateway.py 513ad2e [Kousuke Saruta] Modified util.sh to enable to use option including white spaces 28a374e [Kousuke Saruta] Modified java_gateway.py to recognize arguments 5afc584 [Cheng Lian] Filter out spark-submit options when starting Python gateway e630d19 [Cheng Lian] Fixing pyspark and spark-shell CLI options --- bin/pyspark | 18 ++++-- bin/spark-shell | 20 +++++-- bin/utils.sh | 59 +++++++++++++++++++ .../spark/deploy/SparkSubmitArguments.scala | 4 ++ dev/merge_spark_pr.py | 2 + python/pyspark/java_gateway.py | 2 +- 6 files changed, 94 insertions(+), 11 deletions(-) create mode 100644 bin/utils.sh diff --git a/bin/pyspark b/bin/pyspark index 39a20e2a24a3c..01d42025c978e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -23,12 +23,18 @@ FWDIR="$(cd `dirname $0`/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +source $FWDIR/bin/utils.sh + SCALA_VERSION=2.10 -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then +function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 exit 0 +} + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage fi # Exit if the user hasn't compiled Spark @@ -66,10 +72,11 @@ fi # Build up arguments list manually to preserve quotes and backslashes. # We export Spark submit arguments as an environment variable because shell.py must run as a # PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. - +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" PYSPARK_SUBMIT_ARGS="" whitespace="[[:space:]]" -for i in "$@"; do +for i in "${SUBMISSION_OPTS[@]}"; do if [[ $i =~ \" ]]; then i=$(echo $i | sed 's/\"/\\\"/g'); fi if [[ $i =~ $whitespace ]]; then i=\"$i\"; fi PYSPARK_SUBMIT_ARGS="$PYSPARK_SUBMIT_ARGS $i" @@ -90,7 +97,10 @@ fi if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 echo -e "Use ./bin/spark-submit \n" 1>&2 - exec $FWDIR/bin/spark-submit "$@" + primary=$1 + shift + gatherSparkSubmitOpts "$@" + exec $FWDIR/bin/spark-submit "${SUBMISSION_OPTS[@]}" $primary "${APPLICATION_OPTS[@]}" else # Only use ipython if no command line arguments were provided [SPARK-1134] if [[ "$IPYTHON" = "1" ]]; then diff --git a/bin/spark-shell b/bin/spark-shell index 756c8179d12b6..8b7ccd7439551 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -31,13 +31,21 @@ set -o posix ## Global script variables FWDIR="$(cd `dirname $0`/..; pwd)" +function usage() { + echo "Usage: ./bin/spark-shell [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +} + if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./bin/spark-shell [options]" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 + usage fi -function main(){ +source $FWDIR/bin/utils.sh +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" + +function main() { if $cygwin; then # Workaround for issue involving JLine and Cygwin # (see http://sourceforge.net/p/jline/bugs/40/). @@ -46,11 +54,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" fi } diff --git a/bin/utils.sh b/bin/utils.sh new file mode 100644 index 0000000000000..0804b1ed9f231 --- /dev/null +++ b/bin/utils.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Gather all all spark-submit options into SUBMISSION_OPTS +function gatherSparkSubmitOpts() { + + if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then + echo "Function for printing usage of $0 is not set." 1>&2 + echo "Please set usage function to shell variable 'SUBMIT_USAGE_FUNCTION' in $0" 1>&2 + exit 1 + fi + + # NOTE: If you add or remove spark-sumbmit options, + # modify NOT ONLY this script but also SparkSubmitArgument.scala + SUBMISSION_OPTS=() + APPLICATION_OPTS=() + while (($#)); do + case "$1" in + --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ + --conf | --properties-file | --driver-memory | --driver-java-options | \ + --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ + --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) + if [[ $# -lt 2 ]]; then + "$SUBMIT_USAGE_FUNCTION" + exit 1; + fi + SUBMISSION_OPTS+=("$1"); shift + SUBMISSION_OPTS+=("$1"); shift + ;; + + --verbose | -v | --supervise) + SUBMISSION_OPTS+=("$1"); shift + ;; + + *) + APPLICATION_OPTS+=("$1"); shift + ;; + esac + done + + export SUBMISSION_OPTS + export APPLICATION_OPTS +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c21f1529a1837..d545f58c5da7e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -224,6 +224,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { // Delineates parsing of Spark options from parsing of user options. parse(opts) + /** + * NOTE: If you add or remove spark-submit options, + * modify NOT ONLY this file but also utils.sh + */ def parse(opts: Seq[String]): Unit = opts match { case ("--name") :: value :: tail => name = value diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 53df9b5a3f1d5..d48c8bde12905 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -74,8 +74,10 @@ def fail(msg): def run_cmd(cmd): if isinstance(cmd, list): + print " ".join(cmd) return subprocess.check_output(cmd) else: + print cmd return subprocess.check_output(cmd.split(" ")) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 37386ab0d7d49..c7f7c1fe591b0 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -39,7 +39,7 @@ def launch_gateway(): submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS") submit_args = submit_args if submit_args is not None else "" submit_args = shlex.split(submit_args) - command = [os.path.join(SPARK_HOME, script), "pyspark-shell"] + submit_args + command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): From e8f8e5f4a2f2cbda695608cc3b0e13fcfa66d487 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 9 Aug 2014 22:05:36 -0700 Subject: [PATCH 090/231] Updated Spark SQL README to include the hive-thriftserver module Author: Reynold Xin Closes #1867 from rxin/sql-readme and squashes the following commits: 42a5307 [Reynold Xin] Updated Spark SQL README to include the hive-thriftserver module (cherry picked from commit 5b6585de6b939837d5bdc4b1a44634301949add6) Signed-off-by: Reynold Xin --- sql/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/README.md b/sql/README.md index 14d5555f0c713..31f9152344086 100644 --- a/sql/README.md +++ b/sql/README.md @@ -3,10 +3,11 @@ Spark SQL This module provides support for executing relational queries expressed in either SQL or a LINQ-like Scala DSL. -Spark SQL is broken up into three subprojects: +Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst’s logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Other dependencies for developers From 076ddda6a9c0b4ca4f167fdf59e9a99fc0fce81f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 9 Aug 2014 23:06:54 -0700 Subject: [PATCH 091/231] Turn UpdateBlockInfo into case class. This helps us log UpdateBlockInfo properly once #1870 is merged. Author: Reynold Xin Closes #1872 from rxin/UpdateBlockInfo and squashes the following commits: 0cee1c2 [Reynold Xin] Turn UpdateBlockInfo into case class. (cherry picked from commit 482c5afbf6f3f12ac23851300a33249b26ddff3c) Signed-off-by: Reynold Xin --- .../spark/storage/BlockManagerMessages.scala | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 10b65286fb7db..2ba16b8476600 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -53,7 +53,7 @@ private[spark] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - class UpdateBlockInfo( + case class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, var storageLevel: StorageLevel, @@ -84,24 +84,6 @@ private[spark] object BlockManagerMessages { } } - object UpdateBlockInfo { - def apply( - blockManagerId: BlockManagerId, - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long): UpdateBlockInfo = { - new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize) - } - - // For pattern-matching - def unapply(h: UpdateBlockInfo) - : Option[(BlockManagerId, BlockId, StorageLevel, Long, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize, h.tachyonSize)) - } - } - case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster From bb23b118eb32db67779ff010fa33273e46b123f2 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sun, 10 Aug 2014 12:12:22 -0700 Subject: [PATCH 092/231] Remove extra semicolon in Task.scala Author: GuoQiang Li Closes #1876 from witgo/remove_semicolon_in_Task_scala and squashes the following commits: c6ea732 [GuoQiang Li] Remove extra semicolon in Task.scala (cherry picked from commit 3570119c34ab8d61507e7703a171b742fb0957d4) Signed-off-by: Reynold Xin --- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c5e421404a21..cbe0bc0bcb0a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -46,7 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) - context.taskMetrics.hostname = Utils.localHostName(); + context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) From 92daffed4c17e373a06333c85124075d0fd18f0c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 10 Aug 2014 13:00:38 -0700 Subject: [PATCH 093/231] [SPARK-2898] [PySpark] fix bugs in deamon.py 1. do not use signal handler for SIGCHILD, it's easy to cause deadlock 2. handle EINTR during accept() 3. pass errno into JVM 4. handle EAGAIN during fork() Now, it can pass 50k tasks tests in 180 seconds. Author: Davies Liu Closes #1842 from davies/qa and squashes the following commits: f0ea451 [Davies Liu] fix lint 03a2e8c [Davies Liu] cleanup dead children every seconds 32cb829 [Davies Liu] fix lint 0cd0817 [Davies Liu] fix bugs in deamon.py (cherry picked from commit 28dcbb531ae57dc50f15ad9df6c31022731669c9) Signed-off-by: Josh Rosen --- .../api/python/PythonWorkerFactory.scala | 2 +- python/pyspark/daemon.py | 78 +++++++++++-------- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 7af260d0b7f26..bf716a8ab025b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -68,7 +68,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val socket = new Socket(daemonHost, daemonPort) val pid = new DataInputStream(socket.getInputStream).readInt() if (pid < 0) { - throw new IllegalStateException("Python daemon failed to launch worker") + throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } daemonWorkers.put(socket, pid) socket diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index e73538baf0b93..22ab8d30c0ae3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -22,7 +22,8 @@ import socket import sys import traceback -from errno import EINTR, ECHILD +import time +from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main @@ -80,6 +81,17 @@ def waitSocketClose(sock): os._exit(compute_real_exit_code(exit_code)) +# Cleanup zombie children +def cleanup_dead_children(): + try: + while True: + pid, _ = os.waitpid(0, os.WNOHANG) + if not pid: + break + except: + pass + + def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -102,29 +114,21 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP - # Cleanup zombie children - def handle_sigchld(*args): - try: - pid, status = os.waitpid(0, os.WNOHANG) - if status != 0: - msg = "worker %s crashed abruptly with exit status %s" % (pid, status) - print >> sys.stderr, msg - except EnvironmentError as err: - if err.errno not in (ECHILD, EINTR): - raise - signal.signal(SIGCHLD, handle_sigchld) - # Initialization complete sys.stdout.close() try: while True: try: - ready_fds = select.select([0, listen_sock], [], [])[0] + ready_fds = select.select([0, listen_sock], [], [], 1)[0] except select.error as ex: if ex[0] == EINTR: continue else: raise + + # cleanup in signal handler will cause deadlock + cleanup_dead_children() + if 0 in ready_fds: try: worker_pid = read_int(sys.stdin) @@ -137,29 +141,41 @@ def handle_sigchld(*args): pass # process already died if listen_sock in ready_fds: - sock, addr = listen_sock.accept() + try: + sock, _ = listen_sock.accept() + except OSError as e: + if e.errno == EINTR: + continue + raise + # Launch a worker process try: pid = os.fork() - if pid == 0: - listen_sock.close() - try: - worker(sock) - except: - traceback.print_exc() - os._exit(1) - else: - os._exit(0) + except OSError as e: + if e.errno in (EAGAIN, EINTR): + time.sleep(1) + pid = os.fork() # error here will shutdown daemon else: + outfile = sock.makefile('w') + write_int(e.errno, outfile) # Signal that the fork failed + outfile.flush() + outfile.close() sock.close() - - except OSError as e: - print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - write_int(-1, outfile) # Signal that the fork failed - outfile.flush() - outfile.close() + continue + + if pid == 0: + # in child process + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: sock.close() + finally: shutdown(1) From 3def842d941a29ca75e8e6c447952655654dc44d Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Sun, 10 Aug 2014 16:31:07 -0700 Subject: [PATCH 094/231] [SPARK-2937] Separate out samplyByKeyExact as its own API in PairRDDFunction To enable Python consistency and `Experimental` label of the `sampleByKeyExact` API. Author: Doris Xin Author: Xiangrui Meng Closes #1866 from dorx/stratified and squashes the following commits: 0ad97b2 [Doris Xin] reviewer comments. 2948aae [Doris Xin] remove unrelated changes e990325 [Doris Xin] Merge branch 'master' into stratified 555a3f9 [Doris Xin] separate out sampleByKeyExact as its own API 616e55c [Doris Xin] merge master 245439e [Doris Xin] moved minSamplingRate to getUpperBound eaf5771 [Doris Xin] bug fixes. 17a381b [Doris Xin] fixed a merge issue and a failed unit ea7d27f [Doris Xin] merge master b223529 [Xiangrui Meng] use approx bounds for poisson fix poisson mean for waitlisting add unit tests for Java b3013a4 [Xiangrui Meng] move math3 back to test scope eecee5f [Doris Xin] Merge branch 'master' into stratified f4c21f3 [Doris Xin] Reviewer comments a10e68d [Doris Xin] style fix a2bf756 [Doris Xin] Merge branch 'master' into stratified 680b677 [Doris Xin] use mapPartitionWithIndex instead 9884a9f [Doris Xin] style fix bbfb8c9 [Doris Xin] Merge branch 'master' into stratified ee9d260 [Doris Xin] addressed reviewer comments 6b5b10b [Doris Xin] Merge branch 'master' into stratified 254e03c [Doris Xin] minor fixes and Java API. 4ad516b [Doris Xin] remove unused imports from PairRDDFunctions bd9dc6e [Doris Xin] unit bug and style violation fixed 1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check 944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate 0214a76 [Doris Xin] cleanUp 90d94c0 [Doris Xin] merge master 9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey 7327611 [Doris Xin] merge master 50581fc [Doris Xin] added a TODO for logging in python 46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being passed into the aggregate function 7e1a481 [Doris Xin] changed the permission on SamplingUtil 1d413ce [Doris Xin] fixed checkstyle issues 9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS (cherry picked from commit b715aa0c8090cd57158ead2a1b35632cb98a6277) Signed-off-by: Xiangrui Meng --- .../apache/spark/api/java/JavaPairRDD.scala | 68 +++--- .../apache/spark/rdd/PairRDDFunctions.scala | 51 +++-- .../java/org/apache/spark/JavaAPISuite.java | 20 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 205 +++++++++++------- 4 files changed, 216 insertions(+), 128 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 76d4193e96aea..feeb6c02caa78 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. */ def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], - exact: Boolean, seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double], - exact: Boolean): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong) + fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** - * Return a subset of this RDD sampled by key (via stratified sampling). - * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. */ - def sampleByKey(withReplacement: Boolean, + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, seed) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) /** - * Return a subset of this RDD sampled by key (via stratified sampling). + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ - def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, Utils.random.nextLong) + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) /** * Return the union of this RDD and another one. Any identical elements will appear multiple diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 5dd6472b0776c..f6d9d12fe9006 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use - * additional passes over the RDD to create a sample size that's exactly equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling - * without replacement, we need one additional pass over the RDD to guarantee sample size; - * when sampling with replacement, we need two additional passes. + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. * * @param withReplacement whether to sample with or without replacement * @param fractions map of specific keys to sampling rates * @param seed seed for the random number generator - * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key * @return RDD containing the sampled subset */ def sampleByKey(withReplacement: Boolean, fractions: Map[K, Double], - exact: Boolean = false, - seed: Long = Utils.random.nextLong): RDD[(K, V)]= { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + + require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") + + val samplingFunc = if (withReplacement) { + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed) + } else { + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed) + } + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + + /** + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). + * + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. + * + * @param withReplacement whether to sample with or without replacement + * @param fractions map of specific keys to sampling rates + * @param seed seed for the random number generator + * @return RDD containing the sampled subset + */ + @Experimental + def sampleByKeyExact(withReplacement: Boolean, + fractions: Map[K, Double], + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") val samplingFunc = if (withReplacement) { - StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed) } else { - StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed) } self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 56150caa5d6ba..e1c13de04a0be 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1239,12 +1239,28 @@ public Tuple2 call(Integer i) { Assert.assertTrue(worCounts.size() == 2); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); - JavaPairRDD wrExact = rdd2.sampleByKey(true, fractions, true, 1L); + } + + @Test + @SuppressWarnings("unchecked") + public void sampleByKeyExact() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i % 2, 1); + } + }); + Map fractions = Maps.newHashMap(); + fractions.put(0, 0.5); + fractions.put(1, 1.0); + JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); Assert.assertTrue(wrExactCounts.size() == 2); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); - JavaPairRDD worExact = rdd2.sampleByKey(false, fractions, true, 1L); + JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); Assert.assertTrue(worExactCounts.size() == 2); Assert.assertTrue(worExactCounts.get(0) == 2); diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 4f49d4a1d4d34..63d3ddb4af98a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -84,118 +84,81 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("sampleByKey") { - def stratifier (fractionPositive: Double) = { - (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" - } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { - if (exact) { - return expected == actual - } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 1000000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // Without replacement validation - def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> samplingRate) - val sample = stratifiedData.sampleByKey(false, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } - assert(takeSample.size === takeSample.toSet.size) - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // With replacement validation - def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey().mapValues(count => - math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> samplingRate) - val sample = stratifiedData.sampleByKey(true, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) } - val groupedByKey = takeSample.groupBy(_._1) - for ((key, v) <- groupedByKey) { - if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { - // sample large enough for there to be repeats with high likelihood - assert(v.toSet.size < expectedSampleSize(key)) - } else { - if (exact) { - assert(v.toSet.size <= expectedSampleSize(key)) - } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) - } - } - } - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // Use the same data for the rest of the tests + val fractionPositive = 0.3 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + + // vary seed + for (seed <- defaultSeed to defaultSeed + 5L) { + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, seed, n) } - def checkAllCombos(stratifiedData: RDD[(String, Int)], - samplingRate: Double, - seed: Long, - n: Long) = { - takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n) + // vary sampling rate + for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } + } + test("sampleByKeyExact") { val defaultSeed = 1L // vary RDD size for (n <- List(100, 1000, 1000000)) { val data = sc.parallelize(1 to n, 2) val fractionPositive = 0.3 - val stratifiedData = data.keyBy(stratifier(fractionPositive)) - + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // vary fractionPositive for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) - + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // Use the same data for the rest of the tests val fractionPositive = 0.3 val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) // vary seed for (seed <- defaultSeed to defaultSeed + 5L) { val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, seed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, seed, n) } // vary sampling rate for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } } @@ -556,6 +519,98 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { intercept[IllegalArgumentException] {shuffled.lookup(-1)} } + private object StratifiedAuxiliary { + def stratifier (fractionPositive: Double) = { + (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" + } + + def checkSize(exact: Boolean, + withReplacement: Boolean, + expected: Long, + actual: Long, + p: Double): Boolean = { + if (exact) { + return expected == actual + } + val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + + def testSampleExact(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, true, samplingRate, seed, n) + testPoisson(stratifiedData, true, samplingRate, seed, n) + } + + def testSample(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, false, samplingRate, seed, n) + testPoisson(stratifiedData, false, samplingRate, seed, n) + } + + // Without replacement validation + def testBernoulli(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey() + .mapValues(count => math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(false, fractions, seed) + } else { + stratifiedData.sampleByKey(false, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + assert(takeSample.size === takeSample.toSet.size) + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + // With replacement validation + def testPoisson(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey().mapValues(count => + math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(true, fractions, seed) + } else { + stratifiedData.sampleByKey(true, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case (k, v) => + assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + } + val groupedByKey = takeSample.groupBy(_._1) + for ((key, v) <- groupedByKey) { + if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { + // sample large enough for there to be repeats with high likelihood + assert(v.toSet.size < expectedSampleSize(key)) + } else { + if (exact) { + assert(v.toSet.size <= expectedSampleSize(key)) + } else { + assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + } + } + } + takeSample.foreach(x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]")) + } + } + } /* From 09b8a3ce0d73915d573e0ebc3e96448736b89bfa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 Aug 2014 11:54:09 -0700 Subject: [PATCH 095/231] [PySpark] [SPARK-2954] [SPARK-2948] [SPARK-2910] [SPARK-2101] Python 2.6 Fixes - Modify python/run-tests to test with Python 2.6 - Use unittest2 when running on Python 2.6. - Fix issue with namedtuple. - Skip TestOutputFormat.test_newhadoop on Python 2.6 until SPARK-2951 is fixed. - Fix MLlib _deserialize_double on Python 2.6. Closes #1868. Closes #1042. Author: Josh Rosen Closes #1874 from JoshRosen/python2.6 and squashes the following commits: 983d259 [Josh Rosen] [SPARK-2954] Fix MLlib _deserialize_double on Python 2.6. 5d18fd7 [Josh Rosen] [SPARK-2948] [SPARK-2910] [SPARK-2101] Python 2.6 fixes (cherry picked from commit db06a81fb7a413faa3fe0f8c35918f70454cb05d) Signed-off-by: Josh Rosen --- python/pyspark/mllib/_common.py | 11 ++++++++++- python/pyspark/mllib/tests.py | 7 ++++++- python/pyspark/serializers.py | 4 ++-- python/pyspark/tests.py | 13 ++++++++++--- python/run-tests | 8 ++++++++ 5 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index db341da85f865..bb60d3d0c8463 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -16,6 +16,7 @@ # import struct +import sys import numpy from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD @@ -78,6 +79,14 @@ LABELED_POINT_MAGIC = 4 +# Workaround for SPARK-2954: before Python 2.7, struct.unpack couldn't unpack bytearray()s. +if sys.version_info[:2] <= (2, 6): + def _unpack(fmt, string): + return struct.unpack(fmt, buffer(string)) +else: + _unpack = struct.unpack + + def _deserialize_numpy_array(shape, ba, offset, dtype=float64): """ Deserialize a numpy array of the given type from an offset in @@ -191,7 +200,7 @@ def _deserialize_double(ba, offset=0): raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba)) if len(ba) - offset != 8: raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb) - return struct.unpack("d", ba[offset:])[0] + return _unpack("d", ba[offset:])[0] def _deserialize_double_vector(ba, offset=0): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 6f3ec8ac94bac..8a851bd35c0e8 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -19,8 +19,13 @@ Fuller unit tests for Python MLlib. """ +import sys from numpy import array, array_equal -import unittest + +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest from pyspark.mllib._common import _convert_vector, _serialize_double_vector, \ _deserialize_double_vector, _dot, _squared_distance diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b35558db3e007..df90cafb245bf 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -314,8 +314,8 @@ def _copy_func(f): _old_namedtuple = _copy_func(collections.namedtuple) - def namedtuple(name, fields, verbose=False, rename=False): - cls = _old_namedtuple(name, fields, verbose, rename) + def namedtuple(*args, **kwargs): + cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) # replace namedtuple with new one diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 88a61176e51ab..22b51110ed671 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -29,9 +29,14 @@ import sys import tempfile import time -import unittest import zipfile +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest + + from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int @@ -605,6 +610,7 @@ def test_oldhadoop(self): conf=input_conf).collect()) self.assertEqual(old_dataset, dict_data) + @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed") def test_newhadoop(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -905,8 +911,9 @@ def createFileInZip(self, name, content): pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) path = os.path.join(self.programDir, name + ".zip") - with zipfile.ZipFile(path, 'w') as zip: - zip.writestr(name, content) + zip = zipfile.ZipFile(path, 'w') + zip.writestr(name, content) + zip.close() return path def test_single_script(self): diff --git a/python/run-tests b/python/run-tests index 48feba2f5bd63..1218edcbd7e08 100755 --- a/python/run-tests +++ b/python/run-tests @@ -48,6 +48,14 @@ function run_test() { echo "Running PySpark tests. Output is in python/unit-tests.log." +# Try to test with Python 2.6, since that's the minimum version that we support: +if [ $(which python2.6) ]; then + export PYSPARK_PYTHON="python2.6" +fi + +echo "Testing with Python version:" +$PYSPARK_PYTHON --version + run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "pyspark/conf.py" From 6ec13745093e983836098c5828a4d4f4e8cc2f54 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 11 Aug 2014 15:25:21 -0700 Subject: [PATCH 096/231] [SPARK-2952] Enable logging actor messages at DEBUG level Example messages: ``` 14/08/09 21:37:01 DEBUG BlockManagerMasterActor: [actor] received message RegisterBlockManager(BlockManagerId(0, rxin-mbp, 58092, 0),278302556,Actor[akka.tcp://spark@rxin-mbp:58088/user/BlockManagerActor1#-63596539]) from Actor[akka.tcp://spark@rxin-mbp:58088/temp/$c] 14/08/09 21:37:01 DEBUG BlockManagerMasterActor: [actor] handled message (0.279 ms) RegisterBlockManager(BlockManagerId(0, rxin-mbp, 58092, 0),278302556,Actor[akka.tcp://spark@rxin-mbp:58088/user/BlockManagerActor1#-63596539]) from Actor[akka.tcp://spark@rxin-mbp:58088/temp/$c] ``` cc @mengxr @tdas @pwendell Author: Reynold Xin Closes #1870 from rxin/actorLogging and squashes the following commits: c531ee5 [Reynold Xin] Added license header for ActorLogReceive. f6b1ebe [Reynold Xin] [SPARK-2952] Enable logging actor messages at DEBUG level (cherry picked from commit 37338666655909502e424b4639d680271d6d4c12) Signed-off-by: Reynold Xin --- .../org/apache/spark/HeartbeatReceiver.scala | 7 +- .../org/apache/spark/MapOutputTracker.scala | 4 +- .../org/apache/spark/deploy/Client.scala | 8 ++- .../spark/deploy/client/AppClient.scala | 6 +- .../apache/spark/deploy/master/Master.scala | 6 +- .../apache/spark/deploy/worker/Worker.scala | 6 +- .../spark/deploy/worker/WorkerWatcher.scala | 8 ++- .../CoarseGrainedExecutorBackend.scala | 7 +- .../CoarseGrainedSchedulerBackend.scala | 9 ++- .../spark/scheduler/local/LocalBackend.scala | 8 +-- .../storage/BlockManagerMasterActor.scala | 11 ++-- .../storage/BlockManagerSlaveActor.scala | 5 +- .../apache/spark/util/ActorLogReceive.scala | 64 +++++++++++++++++++ 13 files changed, 111 insertions(+), 38 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 24ccce21b62ca..83ae57b7f1516 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -21,6 +21,7 @@ import akka.actor.Actor import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.ActorLogReceive /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -36,8 +37,10 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { - override def receive = { +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) + extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 894091761485d..51705c895a55c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -38,10 +38,10 @@ private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage /** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - def receive = { + override def receiveWithLogging = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c07003784e8ac..065ddda50e65e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,12 +27,14 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * Proxy that relays messages to the driver. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging { +private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) + extends Actor with ActorLogReceive with Logging { + var masterActor: ActorSelection = _ val timeout = AkkaUtils.askTimeout(conf) @@ -114,7 +116,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends } } - override def receive = { + override def receiveWithLogging = { case SubmitDriverResponse(success, driverId, message) => println(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index d38e9e79204c2..32790053a6be8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -56,7 +56,7 @@ private[spark] class AppClient( var registered = false var activeMasterUrl: String = null - class ClientActor extends Actor with Logging { + class ClientActor extends Actor with ActorLogReceive with Logging { var master: ActorSelection = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times @@ -119,7 +119,7 @@ private[spark] class AppClient( .contains(remoteUrl.hostPort) } - override def receive = { + override def receiveWithLogging = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index a70ecdb375373..cfa2c028a807b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -42,14 +42,14 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class Master( host: String, port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // to use Akka's scheduler.schedule() @@ -167,7 +167,7 @@ private[spark] class Master( context.stop(leaderElectionAgent) } - override def receive = { + override def receiveWithLogging = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 458d9947bd873..da4fa2f7685d1 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,7 +34,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} /** * @param masterUrls Each url should look like spark://host:port. @@ -51,7 +51,7 @@ private[spark] class Worker( workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher Utils.checkHost(host, "Expected hostname") @@ -187,7 +187,7 @@ private[spark] class Worker( } } - override def receive = { + override def receiveWithLogging = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 530c147000904..6d0d0bbe5ecec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -22,13 +22,15 @@ import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, Di import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat +import org.apache.spark.util.ActorLogReceive /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) extends Actor - with Logging { +private[spark] class WorkerWatcher(workerUrl: String) + extends Actor with ActorLogReceive with Logging { + override def preStart() { context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -48,7 +50,7 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receive = { + override def receiveWithLogging = { case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 1f46a0f176490..13af5b6f5812d 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -31,14 +31,15 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, cores: Int, - sparkProperties: Seq[(String, String)]) extends Actor with ExecutorBackend with Logging { + sparkProperties: Seq[(String, String)]) + extends Actor with ActorLogReceive with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") @@ -52,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend( context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive = { + override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 33500d967ebb1..2a3711ae2a78c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -30,7 +30,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} import org.apache.spark.ui.JettyUtils /** @@ -61,7 +61,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + + override protected def log = CoarseGrainedSchedulerBackend.this.log + private val executorActor = new HashMap[String, ActorRef] private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] @@ -79,7 +82,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } - def receive = { + def receiveWithLogging = { case RegisterExecutor(executorId, hostPort, cores) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3d1cf312ccc97..bec9502f20466 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,9 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} +import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -43,7 +43,7 @@ private case class StopExecutor() private[spark] class LocalActor( scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, - private val totalCores: Int) extends Actor with Logging { + private val totalCores: Int) extends Actor with ActorLogReceive with Logging { private var freeCores = totalCores @@ -53,7 +53,7 @@ private[spark] class LocalActor( val executor = new Executor( localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) - def receive = { + override def receiveWithLogging = { case ReviveOffers => reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index bd31e3c5a187f..3ab07703b6f85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -31,7 +31,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * BlockManagerMasterActor is an actor on the master node to track statuses of @@ -39,7 +39,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -55,8 +55,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", - 60000) + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) var timeoutCheckingTask: Cancellable = null @@ -67,9 +66,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus super.preStart() } - def receive = { + override def receiveWithLogging = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -118,7 +116,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus sender ! true case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") sender ! true if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 6d4db064dff58..c194e0fed3367 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -23,6 +23,7 @@ import akka.actor.{ActorRef, Actor} import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.ActorLogReceive /** * An actor to take commands from the master to execute options. For example, @@ -32,12 +33,12 @@ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // Operations that involve removing blocks may be slow and should be done asynchronously - override def receive = { + override def receiveWithLogging = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala new file mode 100644 index 0000000000000..332d0cbb2dc0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import akka.actor.Actor +import org.slf4j.Logger + +/** + * A trait to enable logging all Akka actor messages. Here's an example of using this: + * + * {{{ + * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { + * ... + * override def receiveWithLogging = { + * case GetLocations(blockId) => + * sender ! getLocations(blockId) + * ... + * } + * ... + * } + * }}} + * + */ +private[spark] trait ActorLogReceive { + self: Actor => + + override def receive: Actor.Receive = new Actor.Receive { + + private val _receiveWithLogging = receiveWithLogging + + override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + + override def apply(o: Any): Unit = { + if (log.isDebugEnabled) { + log.debug(s"[actor] received message $o from ${self.sender}") + } + val start = System.nanoTime + _receiveWithLogging.apply(o) + val timeTaken = (System.nanoTime - start).toDouble / 1000000 + if (log.isDebugEnabled) { + log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") + } + } + } + + def receiveWithLogging: Actor.Receive + + protected def log: Logger +} From 6c64d57fabd8ec08dcc03cdc94381ee7d431fbcf Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 Aug 2014 19:15:01 -0700 Subject: [PATCH 097/231] [SPARK-2931] In TaskSetManager, reset currentLocalityIndex after recomputing locality levels This addresses SPARK-2931, a bug where getAllowedLocalityLevel() could throw ArrayIndexOutOfBoundsException. The fix here is to reset currentLocalityIndex after recomputing the locality levels. Thanks to kayousterhout, mridulm, and lirui-intel for helping me to debug this. Author: Josh Rosen Closes #1896 from JoshRosen/SPARK-2931 and squashes the following commits: 48b60b5 [Josh Rosen] Move FakeRackUtil.cleanUp() info beforeEach(). 6fec474 [Josh Rosen] Set currentLocalityIndex after recomputing locality levels. 9384897 [Josh Rosen] Update SPARK-2931 test to reflect changes in 63bdb1f41b4895e3a9444f7938094438a94d3007. 9ecd455 [Josh Rosen] Apply @mridulm's patch for reproducing SPARK-2931. (cherry picked from commit 7712e724ad69dd0b83754e938e9799d13a4d43b9) Signed-off-by: Josh Rosen --- .../spark/scheduler/TaskSetManager.scala | 11 +++-- .../spark/scheduler/TaskSetManagerSuite.scala | 40 ++++++++++++++++++- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 20a4bd12f93f6..d9d53faf843ff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -690,8 +690,7 @@ private[spark] class TaskSetManager( handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } // recalculate valid locality levels and waits when executor is lost - myLocalityLevels = computeValidLocalityLevels() - localityWaits = myLocalityLevels.map(getLocalityWait) + recomputeLocality() } /** @@ -775,9 +774,15 @@ private[spark] class TaskSetManager( levels.toArray } - def executorAdded() { + def recomputeLocality() { + val previousLocalityLevel = myLocalityLevels(currentLocalityIndex) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) + currentLocalityIndex = getLocalityIndex(previousLocalityLevel) + } + + def executorAdded() { + recomputeLocality() } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ffd23380a886f..93e8ddacf8865 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -154,6 +154,11 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) val MAX_TASK_FAILURES = 4 + override def beforeEach() { + super.beforeEach() + FakeRackUtil.cleanUp() + } + test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) @@ -471,7 +476,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("new executors get added and lost") { // Assign host2 to rack2 - FakeRackUtil.cleanUp() FakeRackUtil.assignHostToRack("host2", "rack2") sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc) @@ -504,7 +508,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } test("test RACK_LOCAL tasks") { - FakeRackUtil.cleanUp() // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") // Assign host2 to rack1 @@ -607,6 +610,39 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } + test("Ensure TaskSetManager is usable after addition of levels") { + // Regression test for SPARK-2931 + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host2", "execB.1"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // Only ANY is valid + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + // Add a new executor + sched.addExecutor("execA", "host1") + sched.addExecutor("execB.2", "host2") + manager.executorAdded() + assert(manager.pendingTasksWithNoPrefs.size === 0) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + assert(manager.resourceOffer("execA", "host1", ANY) !== None) + clock.advance(LOCALITY_WAIT * 4) + assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) + sched.removeExecutor("execA") + sched.removeExecutor("execB.2") + manager.executorLost("execA", "host1") + manager.executorLost("execB.2", "host2") + clock.advance(LOCALITY_WAIT * 4) + sched.addExecutor("execC", "host3") + manager.executorAdded() + // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException: + assert(manager.resourceOffer("execC", "host3", ANY) !== None) + } + + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) From 7e31f7c2770bd62c33d771109433b35996bf6d3c Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Mon, 11 Aug 2014 19:22:14 -0700 Subject: [PATCH 098/231] [SPARK-2515][mllib] Chi Squared test Author: Doris Xin Closes #1733 from dorx/chisquare and squashes the following commits: cafb3a7 [Doris Xin] fixed p-value for extreme case. d286783 [Doris Xin] Merge branch 'master' into chisquare e95e485 [Doris Xin] reviewer comments. 7dde711 [Doris Xin] ChiSqTestResult renaming and changed to Class 80d03e2 [Doris Xin] Reviewer comments. c39eeb5 [Doris Xin] units passed with updated API e90d90a [Doris Xin] Merge branch 'master' into chisquare 7eea80b [Doris Xin] WIP d64c2fb [Doris Xin] Merge branch 'master' into chisquare 5686082 [Doris Xin] facelift bc7eb2e [Doris Xin] unit passed; still need docs and some refactoring 50703a5 [Doris Xin] merge master 4e4e361 [Doris Xin] WIP e6b83f3 [Doris Xin] reviewer comments 3d61582 [Doris Xin] input names 706d436 [Doris Xin] Added API for RDD[Vector] 6598379 [Doris Xin] API and code structure. ff17423 [Doris Xin] WIP (cherry picked from commit 32638b5e74e02410831b391f555223f90c830498) Signed-off-by: Xiangrui Meng --- .../apache/spark/mllib/stat/Statistics.scala | 64 +++++ .../spark/mllib/stat/test/ChiSqTest.scala | 221 ++++++++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 88 +++++++ .../mllib/stat/HypothesisTestSuite.scala | 139 +++++++++++ 4 files changed, 512 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f416a9fbb323d..cf8679610e191 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -19,7 +19,9 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult} import org.apache.spark.rdd.RDD /** @@ -89,4 +91,66 @@ object Statistics { */ @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the + * expected distribution. + * + * Note: the two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @param expected Vector containing the expected categorical counts/relative frequencies. + * `expected` is rescaled if the `expected` sum differs from the `observed` sum. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + ChiSqTest.chiSquared(observed, expected) + } + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform + * distribution, with each category having an expected frequency of `1 / observed.size`. + * + * Note: `observed` cannot contain negative values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test on the input contingency matrix, which cannot contain + * negative entries or columns or rows that sum up to 0. + * + * @param observed The contingency matrix (containing either counts or relative frequencies). + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test for every feature against the label across the input RDD. + * For each feature, the (feature, label) pairs are converted into a contingency matrix for which + * the chi-squared statistic is computed. + * + * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * @return an array containing the ChiSquaredTestResult for every feature against the label. + * The order of the elements in the returned array reflects the order of input features. + */ + @Experimental + def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { + ChiSqTest.chiSquaredFeatures(data) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala new file mode 100644 index 0000000000000..8f6752737402e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import breeze.linalg.{DenseMatrix => BDM} +import cern.jet.stat.Probability.chiSquareComplemented + +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD + +/** + * Conduct the chi-squared test for the input RDDs using the specified method. + * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted + * on an input of type `Matrix` in which independence between columns is assessed. + * We also provide a method for computing the chi-squared statistic between each feature and the + * label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size = + * number of features in the inpuy RDD. + * + * Supported methods for goodness of fit: `pearson` (default) + * Supported methods for independence: `pearson` (default) + * + * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test + */ +private[stat] object ChiSqTest extends Logging { + + /** + * @param name String name for the method. + * @param chiSqFunc Function for computing the statistic given the observed and expected counts. + */ + case class Method(name: String, chiSqFunc: (Double, Double) => Double) + + // Pearson's chi-squared test: http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test + val PEARSON = new Method("pearson", (observed: Double, expected: Double) => { + val dev = observed - expected + dev * dev / expected + }) + + // Null hypothesis for the two different types of chi-squared tests to be included in the result. + object NullHypothesis extends Enumeration { + type NullHypothesis = Value + val goodnessOfFit = Value("observed follows the same distribution as expected.") + val independence = Value("observations in each column are statistically independent.") + } + + // Method identification based on input methodName string + private def methodFromString(methodName: String): Method = { + methodName match { + case PEARSON.name => PEARSON + case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.") + } + } + + /** + * Conduct Pearson's independence test for each feature against the label across the input RDD. + * The contingency table is constructed from the raw (feature, label) pairs and used to conduct + * the independence test. + * Returns an array containing the ChiSquaredTestResult for every feature against the label. + */ + def chiSquaredFeatures(data: RDD[LabeledPoint], + methodName: String = PEARSON.name): Array[ChiSqTestResult] = { + val numCols = data.first().features.size + val results = new Array[ChiSqTestResult](numCols) + var labels: Map[Double, Int] = null + // At most 100 columns at a time + val batchSize = 100 + var batch = 0 + while (batch * batchSize < numCols) { + // The following block of code can be cleaned up and made public as + // chiSquared(data: RDD[(V1, V2)]) + val startCol = batch * batchSize + val endCol = startCol + math.min(batchSize, numCols - startCol) + val pairCounts = data.flatMap { p => + // assume dense vectors + p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => + (col, feature, p.label) + } + }.countByValue() + + if (labels == null) { + // Do this only once for the first column since labels are invariant across features. + labels = + pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap + } + val numLabels = labels.size + pairCounts.keys.groupBy(_._1).map { case (col, keys) => + val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap + val numRows = features.size + val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels)) + keys.foreach { case (_, feature, label) => + val i = features(feature) + val j = labels(label) + contingency(i, j) += pairCounts((col, feature, label)) + } + results(col) = chiSquaredMatrix(Matrices.fromBreeze(contingency), methodName) + } + batch += 1 + } + results + } + + /* + * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies. + * Uniform distribution is assumed when `expected` is not passed in. + */ + def chiSquared(observed: Vector, + expected: Vector = Vectors.dense(Array[Double]()), + methodName: String = PEARSON.name): ChiSqTestResult = { + + // Validate input arguments + val method = methodFromString(methodName) + if (expected.size != 0 && observed.size != expected.size) { + throw new IllegalArgumentException("observed and expected must be of the same size.") + } + val size = observed.size + if (size > 1000) { + logWarning("Chi-squared approximation may not be accurate due to low expected frequencies " + + s" as a result of a large number of categories: $size.") + } + val obsArr = observed.toArray + val expArr = if (expected.size == 0) Array.tabulate(size)(_ => 1.0 / size) else expected.toArray + if (!obsArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the observed vector.") + } + if (expected.size != 0 && ! expArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the expected vector.") + } + + // Determine the scaling factor for expected + val obsSum = obsArr.sum + val expSum = if (expected.size == 0.0) 1.0 else expArr.sum + val scale = if (math.abs(obsSum - expSum) < 1e-7) 1.0 else obsSum / expSum + + // compute chi-squared statistic + val statistic = obsArr.zip(expArr).foldLeft(0.0) { case (stat, (obs, exp)) => + if (exp == 0.0) { + if (obs == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input vectors due" + + " to 0.0 values in both observed and expected.") + } else { + return new ChiSqTestResult(0.0, size - 1, Double.PositiveInfinity, PEARSON.name, + NullHypothesis.goodnessOfFit.toString) + } + } + if (scale == 1.0) { + stat + method.chiSqFunc(obs, exp) + } else { + stat + method.chiSqFunc(obs, exp * scale) + } + } + val df = size - 1 + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString) + } + + /* + * Pearon's independence test on the input contingency matrix. + * TODO: optimize for SparseMatrix when it becomes supported. + */ + def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + val method = methodFromString(methodName) + val numRows = counts.numRows + val numCols = counts.numCols + + // get row and column sums + val colSums = new Array[Double](numCols) + val rowSums = new Array[Double](numRows) + val colMajorArr = counts.toArray + var i = 0 + while (i < colMajorArr.size) { + val elem = colMajorArr(i) + if (elem < 0.0) { + throw new IllegalArgumentException("Contingency table cannot contain negative entries.") + } + colSums(i / numRows) += elem + rowSums(i % numRows) += elem + i += 1 + } + val total = colSums.sum + + // second pass to collect statistic + var statistic = 0.0 + var j = 0 + while (j < colMajorArr.size) { + val col = j / numRows + val colSum = colSums(col) + if (colSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" + + s"0 sum in column [$col].") + } + val row = j % numRows + val rowSum = rowSums(row) + if (rowSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" + + s"0 sum in row [$row].") + } + val expected = colSum * rowSum / total + statistic += method.chiSqFunc(colMajorArr(j), expected) + j += 1 + } + val df = (numCols - 1) * (numRows - 1) + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala new file mode 100644 index 0000000000000..2f278621335e1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Trait for hypothesis test results. + * @tparam DF Return type of `degreesOfFreedom`. + */ +@Experimental +trait TestResult[DF] { + + /** + * The probability of obtaining a test statistic result at least as extreme as the one that was + * actually observed, assuming that the null hypothesis is true. + */ + def pValue: Double + + /** + * Returns the degree(s) of freedom of the hypothesis test. + * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. + */ + def degreesOfFreedom: DF + + /** + * Test statistic. + */ + def statistic: Double + + /** + * String explaining the hypothesis test result. + * Specific classes implementing this trait should override this method to output test-specific + * information. + */ + override def toString: String = { + + // String explaining what the p-value indicates. + val pValueExplain = if (pValue <= 0.01) { + "Very strong presumption against null hypothesis." + } else if (0.01 < pValue && pValue <= 0.05) { + "Strong presumption against null hypothesis." + } else if (0.05 < pValue && pValue <= 0.01) { + "Low presumption against null hypothesis." + } else { + "No presumption against null hypothesis." + } + + s"degrees of freedom = ${degreesOfFreedom.toString} \n" + + s"statistic = $statistic \n" + + s"pValue = $pValue \n" + pValueExplain + } +} + +/** + * :: Experimental :: + * Object containing the test results for the chi squared hypothesis test. + */ +@Experimental +class ChiSqTestResult(override val pValue: Double, + override val degreesOfFreedom: Int, + override val statistic: Double, + val method: String, + val nullHypothesis: String) extends TestResult[Int] { + + override def toString: String = { + "Chi squared test summary: \n" + + s"method: $method \n" + + s"null hypothesis: $nullHypothesis \n" + + super.toString + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala new file mode 100644 index 0000000000000..5bd0521298c14 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.test.ChiSqTest +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class HypothesisTestSuite extends FunSuite with LocalSparkContext { + + test("chi squared pearson goodness of fit") { + + val observed = new DenseVector(Array[Double](4, 6, 5)) + val pearson = Statistics.chiSqTest(observed) + + // Results validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` + assert(pearson.statistic === 0.4) + assert(pearson.degreesOfFreedom === 2) + assert(pearson.pValue ~== 0.8187 relTol 1e-4) + assert(pearson.method === ChiSqTest.PEARSON.name) + assert(pearson.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // different expected and observed sum + val observed1 = new DenseVector(Array[Double](21, 38, 43, 80)) + val expected1 = new DenseVector(Array[Double](3, 5, 7, 20)) + val pearson1 = Statistics.chiSqTest(observed1, expected1) + + // Results validated against the R command + // `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` + assert(pearson1.statistic ~== 14.1429 relTol 1e-4) + assert(pearson1.degreesOfFreedom === 3) + assert(pearson1.pValue ~== 0.002717 relTol 1e-4) + assert(pearson1.method === ChiSqTest.PEARSON.name) + assert(pearson1.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // Vectors with different sizes + val observed3 = new DenseVector(Array(1.0, 2.0, 3.0)) + val expected3 = new DenseVector(Array(1.0, 2.0, 3.0, 4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(observed3, expected3)) + + // negative counts in observed + val negObs = new DenseVector(Array(1.0, 2.0, 3.0, -4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(negObs, expected1)) + + // count = 0.0 in expected but not observed + val zeroExpected = new DenseVector(Array(1.0, 0.0, 3.0)) + val inf = Statistics.chiSqTest(observed, zeroExpected) + assert(inf.statistic === Double.PositiveInfinity) + assert(inf.degreesOfFreedom === 2) + assert(inf.pValue === 0.0) + assert(inf.method === ChiSqTest.PEARSON.name) + assert(inf.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // 0.0 in expected and observed simultaneously + val zeroObserved = new DenseVector(Array(2.0, 0.0, 1.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(zeroObserved, zeroExpected)) + } + + test("chi squared pearson matrix independence") { + val data = Array(40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0) + // [[40.0, 56.0, 31.0, 30.0], + // [24.0, 32.0, 10.0, 15.0], + // [29.0, 42.0, 0.0, 12.0]] + val chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + // Results validated against R command + // `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` + assert(chi.statistic ~== 21.9958 relTol 1e-4) + assert(chi.degreesOfFreedom === 6) + assert(chi.pValue ~== 0.001213 relTol 1e-4) + assert(chi.method === ChiSqTest.PEARSON.name) + assert(chi.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + + // Negative counts + val negCounts = Array(4.0, 5.0, 3.0, -3.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, negCounts))) + + // Row sum = 0.0 + val rowZero = Array(0.0, 1.0, 0.0, 2.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, rowZero))) + + // Column sum = 0.0 + val colZero = Array(0.0, 0.0, 2.0, 2.0) + // IllegalArgumentException thrown here since it's thrown on driver, not inside a task + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, colZero))) + } + + test("chi squared pearson RDD[LabeledPoint]") { + // labels: 1.0 (2 / 6), 0.0 (4 / 6) + // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) + // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) + val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + new LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + for (numParts <- List(2, 4, 6, 8)) { + val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) + val feature1 = chi(0) + assert(feature1.statistic === 0.75) + assert(feature1.degreesOfFreedom === 2) + assert(feature1.pValue ~== 0.6873 relTol 1e-4) + assert(feature1.method === ChiSqTest.PEARSON.name) + assert(feature1.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + val feature2 = chi(1) + assert(feature2.statistic === 1.5) + assert(feature2.degreesOfFreedom === 3) + assert(feature2.pValue ~== 0.6823 relTol 1e-4) + assert(feature2.method === ChiSqTest.PEARSON.name) + assert(feature2.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + } + + // Test that the right number of results is returned + val numCols = 321 + val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0))))) + val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) + assert(chi.size === numCols) + } +} From 8f6e2e9df41e7de22b1d1cbd524e20881f861dd0 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 11 Aug 2014 19:49:29 -0700 Subject: [PATCH 099/231] [SPARK-2934][MLlib] Adding LogisticRegressionWithLBFGS Interface for training with LBFGS Optimizer which will converge faster than SGD. Author: DB Tsai Closes #1862 from dbtsai/dbtsai-lbfgs-lor and squashes the following commits: aa84b81 [DB Tsai] small change f852bcd [DB Tsai] Remove duplicate method f119fdc [DB Tsai] Formatting 97776aa [DB Tsai] address more feedback 85b4a91 [DB Tsai] address feedback 3cf50c2 [DB Tsai] LogisticRegressionWithLBFGS interface (cherry picked from commit 6fab941b65f0cb6c9b32e0f8290d76889cda6a87) Signed-off-by: Xiangrui Meng --- .../classification/LogisticRegression.scala | 51 ++++++++++- .../LogisticRegressionSuite.scala | 89 ++++++++++++++++++- 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2242329b7918e..31d474a20fa85 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -101,7 +101,7 @@ class LogisticRegressionWithSGD private ( } /** - * Top-level methods for calling Logistic Regression. + * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ object LogisticRegressionWithSGD { @@ -188,3 +188,52 @@ object LogisticRegressionWithSGD { train(input, numIterations, 1.0, 1.0) } } + +/** + * Train a classification model for Logistic Regression using Limited-memory BFGS. + * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +class LogisticRegressionWithLBFGS private ( + private var convergenceTol: Double, + private var maxNumIterations: Int, + private var regParam: Double) + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { + + /** + * Construct a LogisticRegression object with default parameters + */ + def this() = this(1E-4, 100, 0.0) + + private val gradient = new LogisticGradient() + private val updater = new SimpleUpdater() + // Have to return new LBFGS object every time since users can reset the parameters anytime. + override def optimizer = new LBFGS(gradient, updater) + .setNumCorrections(10) + .setConvergenceTol(convergenceTol) + .setMaxNumIterations(maxNumIterations) + .setRegParam(regParam) + + override protected val validators = List(DataValidators.binaryLabelValidator) + + /** + * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. + * Smaller value will lead to higher accuracy with the cost of more iterations. + */ + def setConvergenceTol(convergenceTol: Double): this.type = { + this.convergenceTol = convergenceTol + this + } + + /** + * Set the maximal number of iterations for L-BFGS. Default 100. + */ + def setNumIterations(numIterations: Int): this.type = { + this.maxNumIterations = numIterations + this + } + + override protected def createModel(weights: Vector, intercept: Double) = { + new LogisticRegressionModel(weights, intercept) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index da7c633bbd2af..2289c6cdc19de 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -67,7 +67,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match } // Test if we can correctly learn A, B where Y = logistic(A + B*X) - test("logistic regression") { + test("logistic regression with SGD") { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -94,7 +94,36 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } - test("logistic regression with initial weights") { + // Test if we can correctly learn A, B where Y = logistic(A + B*X) + test("logistic regression with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD) + + // Test the weights + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== model.weights(0) relTol 0.01) + assert(model.intercept ~== model.intercept relTol 0.01) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("logistic regression with initial weights with SGD") { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -125,11 +154,42 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("logistic regression with initial weights with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Vectors.dense(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD, initialWeights) + + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.02) + assert(model.intercept ~== 1.97 relTol 0.02) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { - test("task size should be small in both training and prediction") { + test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 val n = 200000 val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => @@ -139,6 +199,29 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = LogisticRegressionWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() } + + test("task size should be small in both training and prediction using LBFGS optimizer") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val model = + (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points) + + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() + } + } From 8cb4e5b47b9b871bf4c0d93d0a747e55f66ca0ec Mon Sep 17 00:00:00 2001 From: Ahir Reddy Date: Mon, 11 Aug 2014 20:06:06 -0700 Subject: [PATCH 100/231] [SPARK-2844][SQL] Correctly set JVM HiveContext if it is passed into Python HiveContext constructor https://issues.apache.org/jira/browse/SPARK-2844 Author: Ahir Reddy Closes #1768 from ahirreddy/python-hive-context-fix and squashes the following commits: 7972d3b [Ahir Reddy] Correctly set JVM HiveContext if it is passed into Python HiveContext constructor (cherry picked from commit 490ecfa20327a636289321ea447722aa32b81657) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 950e275adbf01..36040463e62a9 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -912,6 +912,8 @@ def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. @param sparkContext: The SparkContext to wrap. + @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -1315,6 +1317,18 @@ class HiveContext(SQLContext): It supports running both SQL and HiveQL commands. """ + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + @param sparkContext: The SparkContext to wrap. + @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + @property def _ssql_ctx(self): try: From cf2f8071db567a3d795782ffa95d9d4b5dd6acdb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 11 Aug 2014 20:08:06 -0700 Subject: [PATCH 101/231] [SPARK-2590][SQL] Added option to handle incremental collection, disabled by default JIRA issue: [SPARK-2590](https://issues.apache.org/jira/browse/SPARK-2590) Author: Cheng Lian Closes #1853 from liancheng/inc-collect-option and squashes the following commits: cb3ea45 [Cheng Lian] Moved incremental collection option to Thrift server 43ce3aa [Cheng Lian] Changed incremental collect option name 623abde [Cheng Lian] Added option to handle incremental collection, disabled by default (cherry picked from commit 21a95ef051f7b23a80d147aadb00dfa4ebb169b0) Signed-off-by: Michael Armbrust --- .../server/SparkSQLOperationManager.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index dee092159dd4c..f192f490ac3d0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -132,7 +132,16 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) - iter = result.queryExecution.toRdd.toLocalIterator + iter = { + val resultRdd = result.queryExecution.toRdd + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + resultRdd.toLocalIterator + } else { + resultRdd.collect().iterator + } + } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray setHasResultSet(true) } catch { From 6d0af526cf3fccdd668adcc20407a72764affdd6 Mon Sep 17 00:00:00 2001 From: wangfei Date: Mon, 11 Aug 2014 20:10:13 -0700 Subject: [PATCH 102/231] [sql]use SparkSQLEnv.stop() in ShutdownHook Author: wangfei Closes #1852 from scwf/patch-3 and squashes the following commits: ae28c29 [wangfei] use SparkSQLEnv.stop() in ShutdownHook (cherry picked from commit e83fdcd421d132812411eb805565b76f087f1bc0) Signed-off-by: Michael Armbrust --- .../apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 6f7942aba314a..cadf7aaf42157 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -60,7 +60,7 @@ private[hive] object HiveThriftServer2 extends Logging { Runtime.getRuntime.addShutdownHook( new Thread() { override def run() { - SparkSQLEnv.sparkContext.stop() + SparkSQLEnv.stop() } } ) From fd8173fac6ae0ef329085e47887535c0607a9a8d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 11 Aug 2014 20:11:29 -0700 Subject: [PATCH 103/231] [SQL] A tiny refactoring in HiveContext#analyze I should use `EliminateAnalysisOperators` in `analyze` instead of manually pattern matching. Author: Yin Huai Closes #1881 from yhuai/useEliminateAnalysisOperators and squashes the following commits: f3e1e7f [Yin Huai] Use EliminateAnalysisOperators. (cherry picked from commit 647aeba3a9e101d35083f7c4afbcfe7a33f7fc62) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 53f3dc11dbb9f..a8da676ffa0e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -39,7 +39,8 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators} +import org.apache.spark.sql.catalyst.analysis.{OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException @@ -119,10 +120,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * in the Hive metastore. */ def analyze(tableName: String) { - val relation = catalog.lookupRelation(None, tableName) match { - case LowerCaseSchema(r) => r - case o => o - } + val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) relation match { case relation: MetastoreRelation => { From dcbf079f626c9ef8ab79de60acd817b7bbc5f20d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 11 Aug 2014 20:15:01 -0700 Subject: [PATCH 104/231] [SPARK-2965][SQL] Fix HashOuterJoin output nullabilities. Output attributes of opposite side of `OuterJoin` should be nullable. Author: Takuya UESHIN Closes #1887 from ueshin/issues/SPARK-2965 and squashes the following commits: bcb2d37 [Takuya UESHIN] Fix HashOuterJoin output nullabilities. (cherry picked from commit c9c89c31b6114832fe282c21fecd663d8105b9bc) Signed-off-by: Michael Armbrust --- .../org/apache/spark/sql/execution/joins.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 51bb61530744c..ea075f8c65bff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -168,7 +168,18 @@ case class HashOuterJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - def output = left.output ++ right.output + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. From 54b387f0fa1f57480a7456db138c2e44b5d2c815 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 11 Aug 2014 20:18:03 -0700 Subject: [PATCH 105/231] [SPARK-2968][SQL] Fix nullabilities of Explode. Output nullabilities of `Explode` could be detemined by `ArrayType.containsNull` or `MapType.valueContainsNull`. Author: Takuya UESHIN Closes #1888 from ueshin/issues/SPARK-2968 and squashes the following commits: d128c95 [Takuya UESHIN] Fix nullability of Explode. (cherry picked from commit c686b7dd4668b5e9fc3177f15edeae3446d2e634) Signed-off-by: Michael Armbrust --- .../spark/sql/catalyst/expressions/generators.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3d41acb79e5fd..e99c5b452d183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -86,19 +86,19 @@ case class Explode(attributeNames: Seq[String], child: Expression) (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et, _) => et :: Nil - case MapType(kt,vt, _) => kt :: vt :: Nil + case ArrayType(et, containsNull) => (et, containsNull) :: Nil + case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } // TODO: Move this pattern into Generator. protected def makeOutput() = if (attributeNames.size == elementTypes.size) { attributeNames.zip(elementTypes).map { - case (n, t) => AttributeReference(n, t, nullable = true)() + case (n, (t, nullable)) => AttributeReference(n, t, nullable)() } } else { elementTypes.zipWithIndex.map { - case (t, i) => AttributeReference(s"c_$i", t, nullable = true)() + case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() } } From 779d1eb26d0f031791e93c908d51a59c3b422a55 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 11 Aug 2014 20:21:56 -0700 Subject: [PATCH 106/231] [SPARK-2650][SQL] Build column buffers in smaller batches Author: Michael Armbrust Closes #1880 from marmbrus/columnBatches and squashes the following commits: 0649987 [Michael Armbrust] add test 4756fad [Michael Armbrust] fix compilation 2314532 [Michael Armbrust] Build column buffers in smaller batches (cherry picked from commit bad21ed085a505559dccc06223b486170371ddd2) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 + .../org/apache/spark/sql/SQLContext.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 76 ++++++++++++------- .../apache/spark/sql/CachedTableSuite.scala | 12 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- 7 files changed, 70 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 0fd7aaaa36eb8..35c51dec0bcf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,6 +25,7 @@ import java.util.Properties private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" + val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" @@ -71,6 +72,9 @@ trait SQLConf { /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + /** The number of rows that will be */ + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 71d338d21d0f2..af9f7c62a1d25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -273,7 +273,7 @@ class SQLContext(@transient val sparkContext: SparkContext) currentTable.logicalPlan case _ => - InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) } catalog.registerTable(None, tableName, asInMemoryRelation) @@ -284,7 +284,7 @@ class SQLContext(@transient val sparkContext: SparkContext) table(tableName).queryExecution.analyzed match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. - case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => + case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 88901debbb4e9..3364d0e18bcc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -28,13 +28,14 @@ import org.apache.spark.sql.Row import org.apache.spark.SparkConf object InMemoryRelation { - def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, child)() + def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, child)() } private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, + batchSize: Int, child: SparkPlan) (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) extends LogicalPlan with MultiInstanceRelation { @@ -43,22 +44,31 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { iterator => - val columnBuilders = output.map { attribute => - ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) - }.toArray - - var row: Row = null - while (iterator.hasNext) { - row = iterator.next() - var i = 0 - while (i < row.length) { - columnBuilders(i).appendFrom(row, i) - i += 1 + val cached = child.execute().mapPartitions { baseIterator => + new Iterator[Array[ByteBuffer]] { + def next() = { + val columnBuilders = output.map { attribute => + ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) + }.toArray + + var row: Row = null + var rowCount = 0 + + while (baseIterator.hasNext && rowCount < batchSize) { + row = baseIterator.next() + var i = 0 + while (i < row.length) { + columnBuilders(i).appendFrom(row, i) + i += 1 + } + rowCount += 1 + } + + columnBuilders.map(_.build()) } - } - Iterator.single(columnBuilders.map(_.build())) + def hasNext = baseIterator.hasNext + } }.cache() cached.setName(child.toString) @@ -74,6 +84,7 @@ private[sql] case class InMemoryRelation( new InMemoryRelation( output.map(_.newInstance), useCompression, + batchSize, child)( _cachedColumnBuffers).asInstanceOf[this.type] } @@ -90,22 +101,31 @@ private[sql] case class InMemoryColumnarTableScan( override def execute() = { relation.cachedColumnBuffers.mapPartitions { iterator => - val columnBuffers = iterator.next() - assert(!iterator.hasNext) + // Find the ordinals of the requested columns. If none are requested, use the first. + val requestedColumns = + if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } new Iterator[Row] { - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = - if (attributes.isEmpty) { - Seq(0) - } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) - } + private[this] var columnBuffers: Array[ByteBuffer] = null + private[this] var columnAccessors: Seq[ColumnAccessor] = null + nextBatch() + + private[this] val nextRow = new GenericMutableRow(columnAccessors.length) - val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) - val nextRow = new GenericMutableRow(columnAccessors.length) + def nextBatch() = { + columnBuffers = iterator.next() + columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) + } override def next() = { + if (!columnAccessors.head.hasNext) { + nextBatch() + } + var i = 0 while (i < nextRow.length) { columnAccessors(i).extractTo(nextRow, i) @@ -114,7 +134,7 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors.head.hasNext || iterator.hasNext } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index fbf9bd9dbcdea..befef46d93973 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,9 +22,19 @@ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableSca import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +case class BigData(s: String) + class CachedTableSuite extends QueryTest { TestData // Load test tables. + test("too big for memory") { + val data = "*" * 10000 + sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") + cacheTable("bigData") + assert(table("bigData").count() === 1000000L) + uncacheTable("bigData") + } + test("SPARK-1669: cacheTable should be idempotent") { assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) @@ -37,7 +47,7 @@ class CachedTableSuite extends QueryTest { cacheTable("testData") table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _: InMemoryColumnarTableScan) => + case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) => fail("cacheTable is not idempotent") case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index b561b44ad7ee2..736c0f8571e9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) } test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 82e9c1a248626..3b371211e14cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -137,7 +137,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 85d2496a34cfb..5fcc1bd4b9adf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -45,7 +45,7 @@ private[hive] trait HiveStrategies { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case _ => Nil From f66f260bbede1eb4e4133918812700baa252fba8 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 11 Aug 2014 20:45:14 -0700 Subject: [PATCH 107/231] [SQL] [SPARK-2826] Reduce the memory copy while building the hashmap for HashOuterJoin This is a follow up for #1147 , this PR will improve the performance about 10% - 15% in my local tests. ``` Before: LeftOuterJoin: took 16750 ms ([3000000] records) LeftOuterJoin: took 15179 ms ([3000000] records) RightOuterJoin: took 15515 ms ([3000000] records) RightOuterJoin: took 15276 ms ([3000000] records) FullOuterJoin: took 19150 ms ([6000000] records) FullOuterJoin: took 18935 ms ([6000000] records) After: LeftOuterJoin: took 15218 ms ([3000000] records) LeftOuterJoin: took 13503 ms ([3000000] records) RightOuterJoin: took 13663 ms ([3000000] records) RightOuterJoin: took 14025 ms ([3000000] records) FullOuterJoin: took 16624 ms ([6000000] records) FullOuterJoin: took 16578 ms ([6000000] records) ``` Besides the performance improvement, I also do some clean up as suggested in #1147 Author: Cheng Hao Closes #1765 from chenghao-intel/hash_outer_join_fixing and squashes the following commits: ab1f9e0 [Cheng Hao] Reduce the memory copy while building the hashmap (cherry picked from commit 5d54d71ddbac1fbb26925a8c9138bbb8c0e81db8) Signed-off-by: Michael Armbrust --- .../apache/spark/sql/execution/joins.scala | 54 ++++++++++--------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index ea075f8c65bff..c86811e838bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.{HashMap => JavaHashMap} + import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ @@ -136,14 +138,6 @@ trait HashJoin { } } -/** - * Constant Value for Binary Join Node - */ -object HashOuterJoin { - val DUMMY_LIST = Seq[Row](null) - val EMPTY_LIST = Seq[Row]() -} - /** * :: DeveloperApi :: * Performs a hash based outer join for two child relations by shuffling the data using @@ -181,6 +175,9 @@ case class HashOuterJoin( } } + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. @@ -199,8 +196,8 @@ case class HashOuterJoin( joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all of the // records in right side. // If we didn't get any proper row, then append a single row with empty right @@ -224,8 +221,8 @@ case class HashOuterJoin( joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all of the // records in left side. // If we didn't get any proper row, then append a single row with empty left. @@ -259,10 +256,10 @@ case class HashOuterJoin( rightMatchedSet.add(idx) joinedRow.copy } - } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. @@ -287,18 +284,22 @@ case class HashOuterJoin( } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { - // TODO: Use Spark's HashMap implementation. - val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) - val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + existingMatchList += currentRow.copy() } - - hashTable.toMap[Row, ArrayBuffer[Row]] + + hashTable } def execute() = { @@ -309,21 +310,22 @@ case class HashOuterJoin( // Build HashMap for current partition in right relation val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + import scala.collection.JavaConversions._ val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, - leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } From 872c170c4d764ae700004f55af32d86173d0081d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 11 Aug 2014 22:33:45 -0700 Subject: [PATCH 108/231] [SPARK-2923][MLLIB] Implement some basic BLAS routines Having some basic BLAS operations implemented in MLlib can help simplify the current implementation and improve some performance. Tested on my local machine: ~~~ bin/spark-submit --class org.apache.spark.examples.mllib.BinaryClassification \ examples/target/scala-*/spark-examples-*.jar --algorithm LR --regType L2 \ --regParam 1.0 --numIterations 1000 ~/share/data/rcv1.binary/rcv1_train.binary ~~~ 1. before: ~1m 2. after: ~30s CC: jkbradley Author: Xiangrui Meng Closes #1849 from mengxr/ml-blas and squashes the following commits: ba583a2 [Xiangrui Meng] exclude Vector.copy a4d7d2f [Xiangrui Meng] Merge branch 'master' into ml-blas 6edeab9 [Xiangrui Meng] address comments 940bdeb [Xiangrui Meng] rename MLlibBLAS to BLAS c2a38bc [Xiangrui Meng] enhance dot tests 4cfaac4 [Xiangrui Meng] add apache header 48d01d2 [Xiangrui Meng] add tests for zeros and copy 3b882b1 [Xiangrui Meng] use blas.scal in gradient 735eb23 [Xiangrui Meng] remove d from BLAS routines d2d7d3c [Xiangrui Meng] update gradient and lbfgs 7f78186 [Xiangrui Meng] add zeros to Vectors; add dscal and dcopy to BLAS 14e6645 [Xiangrui Meng] add ddot cbb8273 [Xiangrui Meng] add daxpy test 07db0bb [Xiangrui Meng] Merge branch 'master' into ml-blas e8c326d [Xiangrui Meng] axpy (cherry picked from commit 9038d94e1e50e05de00fd51af4fd7b9280481cdc) Signed-off-by: Xiangrui Meng --- .../org/apache/spark/mllib/linalg/BLAS.scala | 200 ++++++++++++++++++ .../apache/spark/mllib/linalg/Vectors.scala | 35 ++- .../spark/mllib/optimization/Gradient.scala | 60 ++---- .../spark/mllib/optimization/LBFGS.scala | 39 ++-- .../apache/spark/mllib/linalg/BLASSuite.scala | 129 +++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 30 +++ project/MimaExcludes.scala | 5 +- 7 files changed, 432 insertions(+), 66 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala new file mode 100644 index 0000000000000..70e23033c8754 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} + +/** + * BLAS routines for MLlib's vectors and matrices. + */ +private[mllib] object BLAS extends Serializable { + + @transient private var _f2jBLAS: NetlibBLAS = _ + + // For level-1 routines, we use Java implementation. + private def f2jBLAS: NetlibBLAS = { + if (_f2jBLAS == null) { + _f2jBLAS = new F2jBLAS + } + _f2jBLAS + } + + /** + * y += a * x + */ + def axpy(a: Double, x: Vector, y: Vector): Unit = { + require(x.size == y.size) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + axpy(a, sx, dy) + case dx: DenseVector => + axpy(a, dx, dy) + case _ => + throw new UnsupportedOperationException( + s"axpy doesn't support x type ${x.getClass}.") + } + case _ => + throw new IllegalArgumentException( + s"axpy only supports adding to a dense vector but got type ${y.getClass}.") + } + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = { + val n = x.size + f2jBLAS.daxpy(n, a, x.values, 1, y.values, 1) + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = { + val nnz = x.indices.size + if (a == 1.0) { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += x.values(k) + k += 1 + } + } else { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += a * x.values(k) + k += 1 + } + } + } + + /** + * dot(x, y) + */ + def dot(x: Vector, y: Vector): Double = { + require(x.size == y.size) + (x, y) match { + case (dx: DenseVector, dy: DenseVector) => + dot(dx, dy) + case (sx: SparseVector, dy: DenseVector) => + dot(sx, dy) + case (dx: DenseVector, sy: SparseVector) => + dot(sy, dx) + case (sx: SparseVector, sy: SparseVector) => + dot(sx, sy) + case _ => + throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).") + } + } + + /** + * dot(x, y) + */ + private def dot(x: DenseVector, y: DenseVector): Double = { + val n = x.size + f2jBLAS.ddot(n, x.values, 1, y.values, 1) + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: DenseVector): Double = { + val nnz = x.indices.size + var sum = 0.0 + var k = 0 + while (k < nnz) { + sum += x.values(k) * y.values(x.indices(k)) + k += 1 + } + sum + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: SparseVector): Double = { + var kx = 0 + val nnzx = x.indices.size + var ky = 0 + val nnzy = y.indices.size + var sum = 0.0 + // y catching x + while (kx < nnzx && ky < nnzy) { + val ix = x.indices(kx) + while (ky < nnzy && y.indices(ky) < ix) { + ky += 1 + } + if (ky < nnzy && y.indices(ky) == ix) { + sum += x.values(kx) * y.values(ky) + ky += 1 + } + kx += 1 + } + sum + } + + /** + * y = x + */ + def copy(x: Vector, y: Vector): Unit = { + val n = y.size + require(x.size == n) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + var i = 0 + var k = 0 + val nnz = sx.indices.size + while (k < nnz) { + val j = sx.indices(k) + while (i < j) { + dy.values(i) = 0.0 + i += 1 + } + dy.values(i) = sx.values(k) + i += 1 + k += 1 + } + while (i < n) { + dy.values(i) = 0.0 + i += 1 + } + case dx: DenseVector => + Array.copy(dx.values, 0, dy.values, 0, n) + } + case _ => + throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}") + } + } + + /** + * x = a * x + */ + def scal(a: Double, x: Vector): Unit = { + x match { + case sx: SparseVector => + f2jBLAS.dscal(sx.values.size, a, sx.values, 1) + case dx: DenseVector => + f2jBLAS.dscal(dx.values.size, a, dx.values, 1) + case _ => + throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 77b3e8c714997..a45781d12e41e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.linalg import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} -import java.util.Arrays +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ @@ -30,6 +30,8 @@ import org.apache.spark.SparkException /** * Represents a numeric vector, whose index type is Int and value type is Double. + * + * Note: Users should not implement this interface. */ trait Vector extends Serializable { @@ -46,12 +48,12 @@ trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { case v: Vector => - Arrays.equals(this.toArray, v.toArray) + util.Arrays.equals(this.toArray, v.toArray) case _ => false } } - override def hashCode(): Int = Arrays.hashCode(this.toArray) + override def hashCode(): Int = util.Arrays.hashCode(this.toArray) /** * Converts the instance to a breeze vector. @@ -63,6 +65,13 @@ trait Vector extends Serializable { * @param i index */ def apply(i: Int): Double = toBreeze(i) + + /** + * Makes a deep copy of this vector. + */ + def copy: Vector = { + throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + } } /** @@ -127,6 +136,16 @@ object Vectors { }.toSeq) } + /** + * Creates a dense vector of all zeros. + * + * @param size vector size + * @return a zero vector + */ + def zeros(size: Int): Vector = { + new DenseVector(new Array[Double](size)) + } + /** * Parses a string resulted from `Vector#toString` into * an [[org.apache.spark.mllib.linalg.Vector]]. @@ -142,7 +161,7 @@ object Vectors { case Seq(size: Double, indices: Array[Double], values: Array[Double]) => Vectors.sparse(size.toInt, indices.map(_.toInt), values) case other => - throw new SparkException(s"Cannot parse $other.") + throw new SparkException(s"Cannot parse $other.") } } @@ -183,6 +202,10 @@ class DenseVector(val values: Array[Double]) extends Vector { private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int) = values(i) + + override def copy: DenseVector = { + new DenseVector(values.clone()) + } } /** @@ -213,5 +236,9 @@ class SparseVector( data } + override def copy: SparseVector = { + new SparseVector(size, indices.clone(), values.clone()) + } + private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 9d82f011e674a..fdd67160114ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import breeze.linalg.{axpy => brzAxpy} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} /** * :: DeveloperApi :: @@ -61,11 +60,10 @@ abstract class Gradient extends Serializable { @DeveloperApi class LogisticGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - val gradient = brzData * gradientMultiplier + val gradient = data.copy + scal(gradientMultiplier, gradient) val loss = if (label > 0) { math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p @@ -73,7 +71,7 @@ class LogisticGradient extends Gradient { math.log1p(math.exp(margin)) - margin } - (Vectors.fromBreeze(gradient), loss) + (gradient, loss) } override def compute( @@ -81,13 +79,9 @@ class LogisticGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) - + axpy(gradientMultiplier, data, cumGradient) if (label > 0) { math.log1p(math.exp(margin)) } else { @@ -106,13 +100,11 @@ class LogisticGradient extends Gradient { @DeveloperApi class LeastSquaresGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label + val diff = dot(data, weights) - label val loss = diff * diff - val gradient = brzData * (2.0 * diff) - - (Vectors.fromBreeze(gradient), loss) + val gradient = data.copy + scal(2.0 * diff, gradient) + (gradient, loss) } override def compute( @@ -120,12 +112,8 @@ class LeastSquaresGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label - - brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze) - + val diff = dot(data, weights) - label + axpy(2.0 * diff, data, cumGradient) diff * diff } } @@ -139,18 +127,16 @@ class LeastSquaresGradient extends Gradient { @DeveloperApi class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct) + val gradient = data.copy + scal(-labelScaled, gradient) + (gradient, 1.0 - labelScaled * dotProduct) } else { - (Vectors.dense(new Array[Double](weights.size)), 0.0) + (Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0) } } @@ -159,16 +145,12 @@ class HingeGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - brzAxpy(-labelScaled, brzData, cumGradient.toBreeze) + axpy(-labelScaled, data, cumGradient) 1.0 - labelScaled * dotProduct } else { 0.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 26a2b62e76ed0..033fe44f34f3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, axpy} +import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -192,31 +193,29 @@ object LBFGS extends Logging { regParam: Double, numExamples: Long) extends DiffFunction[BDV[Double]] { - private var i = 0 - - override def calculate(weights: BDV[Double]) = { + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { // Have a local copy to avoid the serialization of CostFun object which is not serializable. + val w = Vectors.fromBreeze(weights) + val n = w.size + val bcW = data.context.broadcast(w) val localGradient = gradient - val n = weights.length - val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))( + val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( - features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) + features, label, bcW.value, grad) (grad, loss + l) }, combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + axpy(1.0, grad2, grad1) + (grad1, loss1 + loss2) }) /** * regVal is sum of weight squares if it's L2 updater; * for other updater, the same logic is followed. */ - val regVal = updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2 val loss = lossSum / numExamples + regVal /** @@ -236,17 +235,13 @@ object LBFGS extends Logging { */ // The following gradientTotal is actually the regularization part of gradient. // Will add the gradientSum computed from the data with weights in the next step. - val gradientTotal = weights - updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze + val gradientTotal = w.copy + axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal) // gradientTotal = gradientSum / numExamples + gradientTotal axpy(1.0 / numExamples, gradientSum, gradientTotal) - i += 1 - - (loss, gradientTotal) + (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]]) } } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala new file mode 100644 index 0000000000000..1952e6734ecf7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.linalg.BLAS._ + +class BLASSuite extends FunSuite { + + test("copy") { + val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0) + val sy = Vectors.sparse(4, Array(0, 1, 3), Array(2.0, 1.0, 1.0)) + val dy = Array(2.0, 1.0, 0.0, 1.0) + + val dy1 = Vectors.dense(dy.clone()) + copy(sx, dy1) + assert(dy1 ~== dx absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + copy(dx, dy2) + assert(dy2 ~== dx absTol 1e-15) + + intercept[IllegalArgumentException] { + copy(sx, sy) + } + + intercept[IllegalArgumentException] { + copy(dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + copy(sx, Vectors.dense(0.0, 1.0, 2.0)) + } + } + } + + test("scal") { + val a = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + + scal(a, sx) + assert(sx ~== Vectors.sparse(3, Array(0, 2), Array(0.1, -0.2)) absTol 1e-15) + + scal(a, dx) + assert(dx ~== Vectors.dense(0.1, 0.0, -0.2) absTol 1e-15) + } + + test("axpy") { + val alpha = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val dy = Array(2.0, 1.0, 0.0) + val expected = Vectors.dense(2.1, 1.0, -0.2) + + val dy1 = Vectors.dense(dy.clone()) + axpy(alpha, sx, dy1) + assert(dy1 ~== expected absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + axpy(alpha, dx, dy2) + assert(dy2 ~== expected absTol 1e-15) + + val sy = Vectors.sparse(4, Array(0, 1), Array(2.0, 1.0)) + + intercept[IllegalArgumentException] { + axpy(alpha, sx, sy) + } + + intercept[IllegalArgumentException] { + axpy(alpha, dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + axpy(alpha, sx, Vectors.dense(1.0, 2.0)) + } + } + } + + test("dot") { + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0)) + val dy = Vectors.dense(2.0, 1.0, 0.0) + + assert(dot(sx, sy) ~== 2.0 absTol 1e-15) + assert(dot(sy, sx) ~== 2.0 absTol 1e-15) + assert(dot(sx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, sx) ~== 2.0 absTol 1e-15) + assert(dot(dx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, dx) ~== 2.0 absTol 1e-15) + + assert(dot(sx, sx) ~== 5.0 absTol 1e-15) + assert(dot(dx, dx) ~== 5.0 absTol 1e-15) + assert(dot(sx, dx) ~== 5.0 absTol 1e-15) + assert(dot(dx, sx) ~== 5.0 absTol 1e-15) + + val sx1 = Vectors.sparse(10, Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0)) + val sx2 = Vectors.sparse(10, Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0)) + assert(dot(sx1, sx2) ~== 20.0 absTol 1e-15) + assert(dot(sx2, sx1) ~== 20.0 absTol 1e-15) + + withClue("vector sizes must match") { + intercept[Exception] { + dot(sx, Vectors.dense(2.0, 1.0)) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 7972ceea1fe8a..cd651fe2d2ddf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -125,4 +125,34 @@ class VectorsSuite extends FunSuite { } } } + + test("zeros") { + assert(Vectors.zeros(3) === Vectors.dense(0.0, 0.0, 0.0)) + } + + test("Vector.copy") { + val sv = Vectors.sparse(4, Array(0, 2), Array(1.0, 2.0)) + val svCopy = sv.copy + (sv, svCopy) match { + case (sv: SparseVector, svCopy: SparseVector) => + assert(sv.size === svCopy.size) + assert(sv.indices === svCopy.indices) + assert(sv.values === svCopy.values) + assert(!sv.indices.eq(svCopy.indices)) + assert(!sv.values.eq(svCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${svCopy.getClass} on ${sv.getClass}.") + } + + val dv = Vectors.dense(1.0, 0.0, 2.0) + val dvCopy = dv.copy + (dv, dvCopy) match { + case (dv: DenseVector, dvCopy: DenseVector) => + assert(dv.size === dvCopy.size) + assert(dv.values === dvCopy.values) + assert(!dv.values.eq(dvCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") + } + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b4653c72c10b5..6e72035f2c15b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -111,9 +111,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") ) ++ - Seq ( // package-private classes removed in MLlib + Seq( // package-private classes removed in MLlib ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") + ) ++ + Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") ) case v if v.startsWith("1.0") => Seq( From 2a8117a994c1a86199bd0610ce9a784311b2596d Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Tue, 12 Aug 2014 00:28:00 -0700 Subject: [PATCH 109/231] [MLlib] Correctly set vectorSize and alpha mengxr Correctly set vectorSize and alpha in Word2Vec training. Author: Liquan Pei Closes #1900 from Ishiihara/Word2Vec-bugfix and squashes the following commits: 85f64f2 [Liquan Pei] correctly set vectorSize and alpha (cherry picked from commit f0060b75ff67ab60babf54149a6860edc53cb6e9) Signed-off-by: Xiangrui Meng --- .../apache/spark/mllib/feature/Word2Vec.scala | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 395037e1ec47c..ecd49ea2ff533 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -119,7 +119,6 @@ class Word2Vec extends Serializable with Logging { private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 private val MAX_SENTENCE_LENGTH = 1000 - private val layer1Size = vectorSize /** context words from [-window, window] */ private val window = 5 @@ -131,7 +130,6 @@ class Word2Vec extends Serializable with Logging { private var vocabSize = 0 private var vocab: Array[VocabWord] = null private var vocabHash = mutable.HashMap.empty[String, Int] - private var alpha = startingAlpha private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) @@ -287,9 +285,10 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) var syn0Global = - Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) - var syn1Global = new Array[Float](vocabSize * layer1Size) + Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) + var syn1Global = new Array[Float](vocabSize * vectorSize) + var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -317,24 +316,24 @@ class Word2Vec extends Serializable with Logging { val c = pos - window + a if (c >= 0 && c < sentence.size) { val lastWord = sentence(c) - val l1 = lastWord * layer1Size - val neu1e = new Array[Float](layer1Size) + val l1 = lastWord * vectorSize + val neu1e = new Array[Float](vectorSize) // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * layer1Size + val l2 = bcVocab.value(word).point(d) * vectorSize // Propagate hidden -> output - var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) + var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat - blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) - blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) } d += 1 } - blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) + blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) } } a += 1 @@ -365,8 +364,8 @@ class Word2Vec extends Serializable with Logging { var i = 0 while (i < vocabSize) { val word = bcVocab.value(i).word - val vector = new Array[Float](layer1Size) - Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) + val vector = new Array[Float](vectorSize) + Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) word2VecMap += word -> vector i += 1 } From b5f80839806e258de7651d851ef01697eb53c127 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Aug 2014 16:26:01 -0700 Subject: [PATCH 110/231] fix flaky tests Python 2.6 does not handle float error well as 2.7+ Author: Davies Liu Closes #1910 from davies/fix_test and squashes the following commits: 7e51200 [Davies Liu] fix flaky tests (cherry picked from commit 882da57a1c8c075a87909d516b169b624941a6ec) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 36040463e62a9..27f1d2ddf942a 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1094,7 +1094,7 @@ def applySchema(self, rdd, schema): ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + ... "float + 1.1 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1...)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), From cffd9bb8d3d025ac2008b54822ee772ec3b28127 Mon Sep 17 00:00:00 2001 From: Ameet Talwalkar Date: Tue, 12 Aug 2014 17:15:21 -0700 Subject: [PATCH 111/231] SPARK-2830 [MLlib]: re-organize mllib documentation As per discussions with Xiangrui, I've reorganized and edited the mllib documentation. Author: Ameet Talwalkar Closes #1908 from atalwalkar/master and squashes the following commits: fe6938a [Ameet Talwalkar] made xiangruis suggested changes 840028b [Ameet Talwalkar] made xiangruis suggested changes 7ec366a [Ameet Talwalkar] reorganize and edit mllib documentation (cherry picked from commit c235b83e2782cce0626ecc403c0a67e442be52c1) Signed-off-by: Xiangrui Meng --- docs/mllib-basics.md | 117 +++++---------------- docs/mllib-classification-regression.md | 37 +++++++ docs/mllib-clustering.md | 15 +-- docs/mllib-collaborative-filtering.md | 21 ++-- docs/mllib-dimensionality-reduction.md | 44 ++++---- docs/mllib-feature-extraction.md | 12 +++ docs/mllib-guide.md | 30 +++--- docs/mllib-linear-methods.md | 134 ++++++++++++------------ docs/mllib-naive-bayes.md | 32 +++--- docs/mllib-stats.md | 95 +++++++++++++++++ 10 files changed, 317 insertions(+), 220 deletions(-) create mode 100644 docs/mllib-classification-regression.md create mode 100644 docs/mllib-feature-extraction.md create mode 100644 docs/mllib-stats.md diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md index f9585251fafac..8752df412950a 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-basics.md @@ -9,17 +9,17 @@ displayTitle: MLlib - Basics MLlib supports local vectors and matrices stored on a single machine, as well as distributed matrices backed by one or more RDDs. -In the current implementation, local vectors and matrices are simple data models -to serve public interfaces. The underlying linear algebra operations are provided by +Local vectors and local matrices are simple data models +that serve as public interfaces. The underlying linear algebra operations are provided by [Breeze](http://www.scalanlp.org/) and [jblas](http://jblas.org/). -A training example used in supervised learning is called "labeled point" in MLlib. +A training example used in supervised learning is called a "labeled point" in MLlib. ## Local vector A local vector has integer-typed and 0-based indices and double-typed values, stored on a single machine. MLlib supports two types of local vectors: dense and sparse. A dense vector is backed by a double array representing its entry values, while a sparse vector is backed by two parallel -arrays: indices and values. For example, a vector $(1.0, 0.0, 3.0)$ can be represented in dense +arrays: indices and values. For example, a vector `(1.0, 0.0, 3.0)` can be represented in dense format as `[1.0, 0.0, 3.0]` or in sparse format as `(3, [0, 2], [1.0, 3.0])`, where `3` is the size of the vector. @@ -44,8 +44,7 @@ val sv1: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)) val sv2: Vector = Vectors.sparse(3, Seq((0, 1.0), (2, 3.0))) {% endhighlight %} -***Note*** - +***Note:*** Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. @@ -110,8 +109,8 @@ sv2 = sps.csc_matrix((np.array([1.0, 3.0]), np.array([0, 2]), np.array([0, 2])), A labeled point is a local vector, either dense or sparse, associated with a label/response. In MLlib, labeled points are used in supervised learning algorithms. We use a double to store a label, so we can use labeled points in both regression and classification. -For binary classification, label should be either $0$ (negative) or $1$ (positive). -For multiclass classification, labels should be class indices staring from zero: $0, 1, 2, \ldots$. +For binary classification, a label should be either `0` (negative) or `1` (positive). +For multiclass classification, labels should be class indices starting from zero: `0, 1, 2, ...`.
    @@ -172,7 +171,7 @@ neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) It is very common in practice to have sparse training data. MLlib supports reading training examples stored in `LIBSVM` format, which is the default format used by [`LIBSVM`](http://www.csie.ntu.edu.tw/~cjlin/libsvm/) and -[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format. Each line +[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format in which each line represents a labeled sparse feature vector using the following format: ~~~ @@ -226,7 +225,7 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") ## Local matrix A local matrix has integer-typed row and column indices and double-typed values, stored on a single -machine. MLlib supports dense matrix, whose entry values are stored in a single double array in +machine. MLlib supports dense matrices, whose entry values are stored in a single double array in column major. For example, the following matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ @@ -234,7 +233,6 @@ column major. For example, the following matrix `\[ \begin{pmatrix} \end{pmatrix} \]` is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the matrix size `(3, 2)`. -We are going to add sparse matrix in the next release.
    @@ -242,7 +240,7 @@ We are going to add sparse matrix in the next release. The base class of local matrices is [`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local matrices. @@ -259,7 +257,7 @@ val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) The base class of local matrices is [`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local matrices. @@ -279,28 +277,30 @@ Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); A distributed matrix has long-typed row and column indices and double-typed values, stored distributively in one or more RDDs. It is very important to choose the right format to store large and distributed matrices. Converting a distributed matrix to a different format may require a -global shuffle, which is quite expensive. We implemented three types of distributed matrices in -this release and will add more types in the future. +global shuffle, which is quite expensive. Three types of distributed matrices have been implemented +so far. The basic type is called `RowMatrix`. A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, e.g., a collection of feature vectors. It is backed by an RDD of its rows, where each row is a local vector. -We assume that the number of columns is not huge for a `RowMatrix`. +We assume that the number of columns is not huge for a `RowMatrix` so that a single +local vector can be reasonably communicated to the driver and can also be stored / +operated on using a single node. An `IndexedRowMatrix` is similar to a `RowMatrix` but with row indices, -which can be used for identifying rows and joins. -A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix) format, +which can be used for identifying rows and executing joins. +A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_.28COO.29) format, backed by an RDD of its entries. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -It is always error-prone to have non-deterministic RDDs. +In general the use of non-deterministic RDDs can lead to errors. ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD -of its rows, where each row is a local vector. This is similar to `data matrix` in the context of -multivariate statistics. Since each row is represented by a local vector, the number of columns is +of its rows, where each row is a local vector. +Since each row is represented by a local vector, the number of columns is limited by the integer range but it should be much smaller in practice.
    @@ -344,70 +344,10 @@ long n = mat.numCols();
    -#### Multivariate summary statistics - -We provide column summary statistics for `RowMatrix`. -If the number of columns is not large, say, smaller than 3000, you can also compute -the covariance matrix as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the -number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, -which could be faster if the rows are sparse. - -
    -
    - -[`RowMatrix#computeColumnSummaryStatistics`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of -[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary - -val mat: RowMatrix = ... // a RowMatrix - -// Compute column summary statistics. -val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() -println(summary.mean) // a dense vector containing the mean value for each column -println(summary.variance) // column-wise variance -println(summary.numNonzeros) // number of nonzeros in each column - -// Compute the covariance matrix. -val cov: Matrix = mat.computeCovariance() -{% endhighlight %} -
    - -
    - -[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of -[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight java %} -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; - -RowMatrix mat = ... // a RowMatrix - -// Compute column summary statistics. -MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); -System.out.println(summary.mean()); // a dense vector containing the mean value for each column -System.out.println(summary.variance()); // column-wise variance -System.out.println(summary.numNonzeros()); // number of nonzeros in each column - -// Compute the covariance matrix. -Matrix cov = mat.computeCovariance(); -{% endhighlight %} -
    -
    - ### IndexedRowMatrix An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by -an RDD of indexed rows, which each row is represented by its index (long-typed) and a local vector. +an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector.
    @@ -467,7 +407,7 @@ RowMatrix rowMat = mat.toRowMatrix(); A `CoordinateMatrix` is a distributed matrix backed by an RDD of its entries. Each entry is a tuple of `(i: Long, j: Long, value: Double)`, where `i` is the row index, `j` is the column index, and -`value` is the entry value. A `CoordinateMatrix` should be used only in the case when both +`value` is the entry value. A `CoordinateMatrix` should be used only when both dimensions of the matrix are huge and the matrix is very sparse.
    @@ -477,9 +417,9 @@ A [`CoordinateMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.CoordinateMatrix) can be created from an `RDD[MatrixEntry]` instance, where [`MatrixEntry`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.MatrixEntry) is a -wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. In this release, we do not provide other -computation for `CoordinateMatrix`. +wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight scala %} import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry} @@ -503,8 +443,9 @@ A [`CoordinateMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.html) can be created from a `JavaRDD` instance, where [`MatrixEntry`](api/java/org/apache/spark/mllib/linalg/distributed/MatrixEntry.html) is a -wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. +wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight java %} import org.apache.spark.api.java.JavaRDD; diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md new file mode 100644 index 0000000000000..719cc95767b00 --- /dev/null +++ b/docs/mllib-classification-regression.md @@ -0,0 +1,37 @@ +--- +layout: global +title: Classification and Regression - MLlib +displayTitle: MLlib - Classification and Regression +--- + +MLlib supports various methods for +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), +[multiclass +classification](http://en.wikipedia.org/wiki/Multiclass_classification), and +[regression analysis](http://en.wikipedia.org/wiki/Regression_analysis). The table below outlines +the supported algorithms for each type of problem. + + + + + + + + + + + + + + + + +
    Problem TypeSupported Methods
    Binary Classificationlinear SVMs, logistic regression, decision trees, naive Bayes
    Multiclass Classificationdecision trees, naive Bayes
    Regressionlinear least squares, Lasso, ridge regression, decision trees
    + +More details for these methods can be found here: + +* [Linear models](mllib-linear-methods.html) + * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) + * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) +* [Decision trees](mllib-decision-tree.html) +* [Naive Bayes](mllib-naive-bayes.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 561de48910132..dfd9cd572888c 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -38,7 +38,7 @@ a given dataset, the algorithm returns the best clustering result).
    -Following code snippets can be executed in `spark-shell`. +The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data @@ -70,7 +70,7 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. A standalone application example -that is equivalent to the provided example in Scala is given bellow: +that is equivalent to the provided example in Scala is given below: {% highlight java %} import org.apache.spark.api.java.*; @@ -113,14 +113,15 @@ public class KMeansExample { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
    -Following examples can be tested in the PySpark shell. +The following examples can be tested in the PySpark shell. In the following example after loading and parsing data, we use the KMeans object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 0d28b5f7c89b3..ab10b2f01f87b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -14,13 +14,13 @@ is commonly used for recommender systems. These techniques aim to fill in the missing entries of a user-item association matrix. MLlib currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -In particular, we implement the [alternating least squares +MLlib uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) algorithm to learn these latent factors. The implementation in MLlib has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). -* *rank* is the number of latent factors in our model. +* *rank* is the number of latent factors in the model. * *iterations* is the number of iterations to run. * *lambda* specifies the regularization parameter in ALS. * *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for @@ -86,8 +86,8 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => println("Mean Squared Error = " + MSE) {% endhighlight %} -If the rating matrix is derived from other source of information (i.e., it is inferred from -other signals), you can use the trainImplicit method to get better results. +If the rating matrix is derived from another source of information (e.g., it is inferred from +other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 @@ -174,10 +174,11 @@ public class CollaborativeFiltering { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
    @@ -219,5 +220,5 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) ## Tutorial -[AMP Camp](http://ampcamp.berkeley.edu/) provides a hands-on tutorial for -[personalized movie recommendation with MLlib](http://ampcamp.berkeley.edu/big-data-mini-course/movie-recommendation-with-mllib.html). +The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for +[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 8e434998c15ea..065d646496131 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -9,9 +9,9 @@ displayTitle: MLlib - Dimensionality Reduction [Dimensionality reduction](http://en.wikipedia.org/wiki/Dimensionality_reduction) is the process of reducing the number of variables under consideration. -It is used to extract latent features from raw and noisy features, +It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -In this release, we provide preliminary support for dimensionality reduction on tall-and-skinny matrices. +MLlib provides support for dimensionality reduction on tall-and-skinny matrices. ## Singular value decomposition (SVD) @@ -30,17 +30,17 @@ where * $V$ is an orthonormal matrix, whose columns are called right singular vectors. For large matrices, usually we don't need the complete factorization but only the top singular -values and its associated singular vectors. This can save storage, and more importantly, de-noise +values and its associated singular vectors. This can save storage, de-noise and recover the low-rank structure of the matrix. -If we keep the top $k$ singular values, then the dimensions of the return will be: +If we keep the top $k$ singular values, then the dimensions of the resulting low-rank matrix will be: * `$U$`: `$m \times k$`, * `$\Sigma$`: `$k \times k$`, * `$V$`: `$n \times k$`. -In this release, we provide SVD computation to row-oriented matrices that have only a few columns, -say, less than $1000$, but many rows, which we call *tall-and-skinny*. +MLlib provides SVD functionality to row-oriented matrices that have only a few columns, +say, less than $1000$, but many rows, i.e., *tall-and-skinny* matrices.
    @@ -58,15 +58,10 @@ val s: Vector = svd.s // The singular values are stored in a local dense vector. val V: Matrix = svd.V // The V factor is a local dense matrix. {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
    -In order to run the following standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. - {% highlight java %} import java.util.LinkedList; @@ -104,8 +99,16 @@ public class SVD { } } {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. + +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`. + +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. +
    @@ -116,7 +119,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -In this release, we implement PCA for tall-and-skinny matrices stored in row-oriented format. +MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format.
    @@ -180,9 +183,10 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md new file mode 100644 index 0000000000000..21453cb9cd8c9 --- /dev/null +++ b/docs/mllib-feature-extraction.md @@ -0,0 +1,12 @@ +--- +layout: global +title: Feature Extraction - MLlib +displayTitle: MLlib - Feature Extraction +--- + +* Table of contents +{:toc} + +## Word2Vec + +## TFIDF diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 95ee6bc96801f..23d5a0c4607af 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -3,18 +3,19 @@ layout: global title: Machine Learning Library (MLlib) --- -MLlib is a Spark implementation of some common machine learning algorithms and utilities, +MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives: +filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: -* [Basics](mllib-basics.html) - * data types +* [Data types](mllib-basics.html) +* [Basic statistics](mllib-stats.html) + * data generators + * stratified sampling * summary statistics -* Classification and regression - * [linear support vector machine (SVM)](mllib-linear-methods.html#linear-support-vector-machine-svm) - * [logistic regression](mllib-linear-methods.html#logistic-regression) - * [linear least squares, Lasso, and ridge regression](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) - * [decision tree](mllib-decision-tree.html) + * hypothesis testing +* [Classification and regression](mllib-classification-regression.html) + * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) + * [decision trees](mllib-decision-tree.html) * [naive Bayes](mllib-naive-bayes.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) @@ -23,17 +24,18 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) -* [Optimization](mllib-optimization.html) +* [Feature extraction and transformation](mllib-feature-extraction.html) +* [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) -MLlib is a new component under active development. +MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and we will provide migration guide between releases. +and the migration guide below will explain all changes between releases. # Dependencies -MLlib uses linear algebra packages [Breeze](http://www.scalanlp.org/), which depends on +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on [netlib-java](https://github.com/fommil/netlib-java), and [jblas](https://github.com/mikiobraun/jblas). `netlib-java` and `jblas` depend on native Fortran routines. @@ -56,7 +58,7 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. +take advantage of sparsity in both storage and computation. Details are described below.
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 254201147edc1..e504cd7f0f578 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -33,24 +33,24 @@ the task of finding a minimizer of a convex function `$f$` that depends on a var Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} - f(\wv) := - \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) + - \lambda\, R(\wv_i) + f(\wv) := \lambda\, R(\wv) + + \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) \label{eq:regPrimal} \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several MLlib's classification and regression algorithms fall into this category, +Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: -the loss that measures the error of the model on the training data, -and the regularizer that measures the complexity of the model. -The loss function `$L(\wv;.)$` must be a convex function in `$\wv$`. -The fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) defines the trade-off -between the two goals of small loss and small model complexity. +the regularizer that controls the complexity of the model, +and the loss that measures the error of the model on the training data. +The loss function `$L(\wv;.)$` is typically a convex function in `$\wv$`. The +fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) +defines the trade-off between the two goals of minimizing the loss (i.e., +training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions @@ -80,10 +80,10 @@ methods MLlib supports: ### Regularizers -The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to -encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid -over-fitting. -We support the following regularizers in MLlib: +The purpose of the +[regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to +encourage simple models and avoid overfitting. We support the following +regularizers in MLlib: @@ -106,27 +106,28 @@ Here `$\mathrm{sign}(\wv)$` is the vector consisting of the signs (`$\pm1$`) of of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. -However, L1 regularization can help promote sparsity in weights, leading to simpler models, which is -also used for feature selection. It is not recommended to train models without any regularization, +However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. +It is not recommended to train models without any regularization, especially when the number of training examples is small. ## Binary classification -[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) is to divide items into -two categories: positive and negative. MLlib supports two linear methods for binary classification: -linear support vector machine (SVM) and logistic regression. The training data set is represented -by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the mathematical -formulation, a training label $y$ is either $+1$ (positive) or $-1$ (negative), which is convenient -for the formulation. *However*, the negative label is represented by $0$ in MLlib instead of $-1$, -to be consistent with multiclass labeling. +[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) +aims to divide items into two categories: positive and negative. MLlib +supports two linear methods for binary classification: linear support vector +machines (SVMs) and logistic regression. For both methods, MLlib supports +L1 and L2 regularized variants. The training data set is represented by an RDD +of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the +mathematical formulation in this guide, a training label $y$ is denoted as +either $+1$ (positive) or $-1$ (negative), which is convenient for the +formulation. *However*, the negative label is represented by $0$ in MLlib +instead of $-1$, to be consistent with multiclass labeling. -### Linear support vector machine (SVM) +### Linear support vector machines (SVMs) The [linear SVM](http://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM) -has become a standard choice for large-scale classification tasks. -The name "linear SVM" is actually ambiguous. -By "linear SVM", we mean specifically the linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the hinge loss +is a standard method for large-scale classification tasks. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the hinge loss: + `\[ L(\wv;\x,y) := \max \{0, 1-y \wv^T \x \}. \]` @@ -134,39 +135,44 @@ By default, linear SVMs are trained with an L2 regularization. We also support alternative L1 regularization. In this case, the problem becomes a [linear program](http://en.wikipedia.org/wiki/Linear_programming). -Linear SVM algorithm outputs a SVM model, which makes predictions based on the value of $\wv^T \x$. -By the default, if $\wv^T \x \geq 0$, the outcome is positive, or negative otherwise. -However, quite often in practice, the default threshold $0$ is not a good choice. -The threshold should be determined via model evaluation. +The linear SVMs algorithm outputs an SVM model. Given a new data point, +denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$. +By the default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative +otherwise. ### Logistic regression [Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a -binary response. It is a linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the logistic loss +binary response. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the logistic loss: `\[ L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). \]` -Logistic regression algorithm outputs a logistic regression model, which makes predictions by +The logistic regression algorithm outputs a logistic regression model. Given a +new data point, denoted by $\x$, the model makes predictions by applying the logistic function `\[ \mathrm{f}(z) = \frac{1}{1 + e^{-z}} \]` where $z = \wv^T \x$. -By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or negative otherwise. -For the same reason mentioned above, quite often in practice, this default threshold is not a good choice. -The threshold should be determined via model evaluation. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or +negative otherwise, though unlike linear SVMs, the raw output of the logistic regression +model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability +that $\x$ is positive). ### Evaluation metrics -MLlib supports common evaluation metrics for binary classification (not available in Python). This +MLlib supports common evaluation metrics for binary classification (not available in PySpark). +This includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score), [receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic), precision-recall curve, and [area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). -Among the metrics, area under ROC is commonly used to compare models and precision/recall/F-measure -can help determine the threshold to use. +AUC is commonly used to compare the performance of various models while +precision/recall/F-measure can help determine the appropriate threshold to use +for prediction purposes. ### Examples @@ -233,8 +239,7 @@ svmAlg.optimizer. val modelL1 = svmAlg.run(training) {% endhighlight %} -Similarly, you can use replace `SVMWithSGD` by -[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD). +[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`. @@ -318,10 +323,11 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
    @@ -354,24 +360,22 @@ print("Training Error = " + str(trainErr)) ## Linear least squares, Lasso, and ridge regression -Linear least squares is a family of linear methods with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the squared loss +Linear least squares is the most common formulation for regression problems. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the squared loss: `\[ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` -Depending on the regularization type, we call the method -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or simply -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) if there -is no regularization, [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) if L2 -regularization is used, and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) if L1 -regularization is used. This average loss $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$ is also +Various related regression methods are derived by using different types of regularization: +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses + no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 +regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 +regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error). -Note that the squared loss is sensitive to outliers. -Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually necessary in practice. - ### Examples
    @@ -379,7 +383,7 @@ Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually ne
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight scala %} @@ -407,9 +411,8 @@ val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) {% endhighlight %} -Similarly you can use [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) -and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD). +and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`.
    @@ -479,16 +482,17 @@ public class LinearRegression { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight python %} diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index b1650c83c98b9..86d94aebd9442 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -4,23 +4,23 @@ title: Naive Bayes - MLlib displayTitle: MLlib - Naive Bayes --- -Naive Bayes is a simple multiclass classification algorithm with the assumption of independence -between every pair of features. Naive Bayes can be trained very efficiently. Within a single pass to -the training data, it computes the conditional probability distribution of each feature given label, -and then it applies Bayes' theorem to compute the conditional probability distribution of label -given an observation and use it for prediction. For more details, please visit the Wikipedia page -[Naive Bayes classifier](http://en.wikipedia.org/wiki/Naive_Bayes_classifier). - -In MLlib, we implemented multinomial naive Bayes, which is typically used for document -classification. Within that context, each observation is a document, each feature represents a term, -whose value is the frequency of the term. For its formulation, please visit the Wikipedia page -[Multinomial Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -or the section -[Naive Bayes text classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) -from the book Introduction to Information -Retrieval. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple +multiclass classification algorithm with the assumption of independence between +every pair of features. Naive Bayes can be trained very efficiently. Within a +single pass to the training data, it computes the conditional probability +distribution of each feature given label, and then it applies Bayes' theorem to +compute the conditional probability distribution of label given an observation +and use it for prediction. + +MLlib supports [multinomial naive +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), +which is typically used for [document +classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each +feature represents a term whose value is the frequency of the term. +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature -vectors are usually sparse. Please supply sparse vectors as input to take advantage of +vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of sparsity. Since the training data is only used once, it is not necessary to cache it. ## Examples diff --git a/docs/mllib-stats.md b/docs/mllib-stats.md new file mode 100644 index 0000000000000..ca9ef46c15186 --- /dev/null +++ b/docs/mllib-stats.md @@ -0,0 +1,95 @@ +--- +layout: global +title: Statistics Functionality - MLlib +displayTitle: MLlib - Statistics Functionality +--- + +* Table of contents +{:toc} + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +## Data Generators + +## Stratified Sampling + +## Summary Statistics + +### Multivariate summary statistics + +We provide column summary statistics for `RowMatrix` (note: this functionality is not currently supported in `IndexedRowMatrix` or `CoordinateMatrix`). +If the number of columns is not large, e.g., on the order of thousands, then the +covariance matrix can also be computed as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the +number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, +and is faster if the rows are sparse. + +
    +
    + +[`computeColumnSummaryStatistics()`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of +[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary + +val mat: RowMatrix = ... // a RowMatrix + +// Compute column summary statistics. +val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() +println(summary.mean) // a dense vector containing the mean value for each column +println(summary.variance) // column-wise variance +println(summary.numNonzeros) // number of nonzeros in each column + +// Compute the covariance matrix. +val cov: Matrix = mat.computeCovariance() +{% endhighlight %} +
    + +
    + +[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of +[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight java %} +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; + +RowMatrix mat = ... // a RowMatrix + +// Compute column summary statistics. +MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); +System.out.println(summary.mean()); // a dense vector containing the mean value for each column +System.out.println(summary.variance()); // column-wise variance +System.out.println(summary.numNonzeros()); // number of nonzeros in each column + +// Compute the covariance matrix. +Matrix cov = mat.computeCovariance(); +{% endhighlight %} +
    +
    + + +## Hypothesis Testing From 837bf60fd0e4597a50c917ad637d7fee4ff47a9a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Aug 2014 22:50:29 -0700 Subject: [PATCH 112/231] [SPARK-2953] Allow using short names for io compression codecs Instead of requiring "org.apache.spark.io.LZ4CompressionCodec", it is easier for users if Spark just accepts "lz4", "lzf", "snappy". Author: Reynold Xin Closes #1873 from rxin/compressionCodecShortForm and squashes the following commits: 9f50962 [Reynold Xin] Specify short-form compression codec names first. 63f78ee [Reynold Xin] Updated configuration documentation. 47b3848 [Reynold Xin] [SPARK-2953] Allow using short names for io compression codecs (cherry picked from commit 676f98289dad61c091bb45bd35a2b9613b22d64a) Signed-off-by: Reynold Xin --- .../org/apache/spark/io/CompressionCodec.scala | 11 +++++++++-- .../spark/io/CompressionCodecSuite.scala | 18 ++++++++++++++++++ docs/configuration.md | 8 +++++--- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 1b66218d86dd9..ef9c43ecf14f6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -46,17 +46,24 @@ trait CompressionCodec { private[spark] object CompressionCodec { + + private val shortCompressionCodecNames = Map( + "lz4" -> classOf[LZ4CompressionCodec].getName, + "lzf" -> classOf[LZFCompressionCodec].getName, + "snappy" -> classOf[SnappyCompressionCodec].getName) + def createCodec(conf: SparkConf): CompressionCodec = { createCodec(conf, conf.get("spark.io.compression.codec", DEFAULT_COMPRESSION_CODEC)) } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val ctor = Class.forName(codecName, true, Utils.getContextOrSparkClassLoader) + val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) .getConstructor(classOf[SparkConf]) ctor.newInstance(conf).asInstanceOf[CompressionCodec] } - val DEFAULT_COMPRESSION_CODEC = classOf[SnappyCompressionCodec].getName + val DEFAULT_COMPRESSION_CODEC = "snappy" } diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 3f882a724b047..25be7f25c21bb 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -56,15 +56,33 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lz4 compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lz4") + assert(codec.getClass === classOf[LZ4CompressionCodec]) + testCodec(codec) + } + test("lzf compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } + test("lzf compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lzf") + assert(codec.getClass === classOf[LZFCompressionCodec]) + testCodec(codec) + } + test("snappy compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) testCodec(codec) } + + test("snappy compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "snappy") + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } } diff --git a/docs/configuration.md b/docs/configuration.md index 617a72a021f6e..8136bd62ab6af 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -373,10 +373,12 @@ Apart from these, the following properties are also available, and may be useful
    - + From be674b34bed93eafeb621cbac5d5bb5f3a60e8f4 Mon Sep 17 00:00:00 2001 From: Raymond Liu Date: Tue, 12 Aug 2014 23:19:35 -0700 Subject: [PATCH 113/231] Use transferTo when copy merge files in ExternalSorter Since this is a file to file copy, using transferTo should be faster. Author: Raymond Liu Closes #1884 from colorant/externalSorter and squashes the following commits: 6e42f3c [Raymond Liu] More code into copyStream bfb496b [Raymond Liu] Use transferTo when copy merge files in ExternalSorter (cherry picked from commit 246cb3f158686348a698d1c0da3001c314727129) Signed-off-by: Reynold Xin --- .../scala/org/apache/spark/util/Utils.scala | 29 ++++++++++++++----- .../util/collection/ExternalSorter.scala | 7 ++--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c60be4f8a11d2..8cac5da644fa9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -284,17 +284,32 @@ private[spark] object Utils extends Logging { /** Copy all data from an InputStream to an OutputStream */ def copyStream(in: InputStream, out: OutputStream, - closeStreams: Boolean = false) + closeStreams: Boolean = false): Long = { + var count = 0L try { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) + if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]) { + // When both streams are File stream, use transferTo to improve copy performance. + val inChannel = in.asInstanceOf[FileInputStream].getChannel() + val outChannel = out.asInstanceOf[FileOutputStream].getChannel() + val size = inChannel.size() + + // In case transferTo method transferred less data than we have required. + while (count < size) { + count += inChannel.transferTo(count, size - count, outChannel) + } + } else { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + count += n + } } } + count } finally { if (closeStreams) { try { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b73d5e0cf1714..5d8a648d9551e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -745,12 +745,11 @@ private[spark] class ExternalSorter[K, V, C]( try { out = new FileOutputStream(outputFile) for (i <- 0 until numPartitions) { - val file = partitionWriters(i).fileSegment().file - in = new FileInputStream(file) - org.apache.spark.util.Utils.copyStream(in, out) + in = new FileInputStream(partitionWriters(i).fileSegment().file) + val size = org.apache.spark.util.Utils.copyStream(in, out, false) in.close() in = null - lengths(i) = file.length() + lengths(i) = size offsets(i + 1) = offsets(i) + lengths(i) } } finally { From ec5e2b0d19233042894301eafdaaffcbc72356de Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 12 Aug 2014 23:43:36 -0700 Subject: [PATCH 114/231] [SPARK-1777 (partial)] bugfix: make size of requested memory correctly Author: Zhang, Liye Closes #1892 from liyezhang556520/lazy_memory_request and squashes the following commits: 335ab61 [Zhang, Liye] [SPARK-1777 (partial)] bugfix: make size of requested memory correctly (cherry picked from commit 2bd812639c3d8c62a725fb7577365ef0816f2898) Signed-off-by: Reynold Xin --- .../src/main/scala/org/apache/spark/storage/MemoryStore.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 28f675c2bbb1e..0a09c24d61879 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -238,7 +238,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // If our vector's size has exceeded the threshold, request more memory val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { - val amountToRequest = (currentSize * (memoryGrowthFactor - 1)).toLong + val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { @@ -254,7 +254,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } // New threshold is currentSize * memoryGrowthFactor - memoryThreshold = currentSize + amountToRequest + memoryThreshold += amountToRequest } } elementsUnrolled += 1 From 5ebeb3fdfa230dbb17b58e53b917a89856686212 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Tue, 12 Aug 2014 23:47:42 -0700 Subject: [PATCH 115/231] [SPARK-2993] [MLLib] colStats (wrapper around MultivariateStatisticalSummary) in Statistics For both Scala and Python. The ser/de util functions were moved out of `PythonMLLibAPI` and into their own object to avoid creating the `PythonMLLibAPI` object inside of `MultivariateStatisticalSummarySerialized`, which is then referenced inside of a method in `PythonMLLibAPI`. `MultivariateStatisticalSummarySerialized` was created to serialize the `Vector` fields in `MultivariateStatisticalSummary`. Author: Doris Xin Closes #1911 from dorx/colStats and squashes the following commits: 77b9924 [Doris Xin] developerAPI tag de9cbbe [Doris Xin] reviewer comments and moved more ser/de 459faba [Doris Xin] colStats in Statistics for both Scala and Python (cherry picked from commit fe4735958e62b1b32a01960503876000f3d2e520) Signed-off-by: Xiangrui Meng --- .../mllib/api/python/PythonMLLibAPI.scala | 532 ++++++++++-------- .../MatrixFactorizationModel.scala | 7 +- .../apache/spark/mllib/stat/Statistics.scala | 13 + .../api/python/PythonMLLibAPISuite.scala | 17 +- python/pyspark/mllib/stat.py | 66 ++- 5 files changed, 374 insertions(+), 261 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index ba7ccd8ce4b8b..18dc087856785 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,7 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -48,182 +48,7 @@ import org.apache.spark.util.Utils */ @DeveloperApi class PythonMLLibAPI extends Serializable { - private val DENSE_VECTOR_MAGIC: Byte = 1 - private val SPARSE_VECTOR_MAGIC: Byte = 2 - private val DENSE_MATRIX_MAGIC: Byte = 3 - private val LABELED_POINT_MAGIC: Byte = 4 - - private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { - require(bytes.length - offset >= 5, "Byte array too short") - val magic = bytes(offset) - if (magic == DENSE_VECTOR_MAGIC) { - deserializeDenseVector(bytes, offset) - } else if (magic == SPARSE_VECTOR_MAGIC) { - deserializeSparseVector(bytes, offset) - } else { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - } - - private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { - require(bytes.length - offset == 8, "Wrong size byte array for Double") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - bb.getDouble - } - private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 5, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val length = bb.getInt() - require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) - val db = bb.asDoubleBuffer() - val ans = new Array[Double](length.toInt) - db.get(ans) - Vectors.dense(ans) - } - - private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 9, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val size = bb.getInt() - val nonZeros = bb.getInt() - require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) - val ib = bb.asIntBuffer() - val indices = new Array[Int](nonZeros) - ib.get(indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - val values = new Array[Double](nonZeros) - db.get(values) - Vectors.sparse(size, indices, values) - } - - /** - * Returns an 8-byte array for the input Double. - * - * Note: we currently do not use a magic byte for double for storage efficiency. - * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. - * The corresponding deserializer, deserializeDouble, needs to be modified as well if the - * serialization scheme changes. - */ - private[python] def serializeDouble(double: Double): Array[Byte] = { - val bytes = new Array[Byte](8) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putDouble(double) - bytes - } - - private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { - val len = doubles.length - val bytes = new Array[Byte](5 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_VECTOR_MAGIC) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(doubles) - bytes - } - - private def serializeSparseVector(vector: SparseVector): Array[Byte] = { - val nonZeros = vector.indices.length - val bytes = new Array[Byte](9 + 12 * nonZeros) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(SPARSE_VECTOR_MAGIC) - bb.putInt(vector.size) - bb.putInt(nonZeros) - val ib = bb.asIntBuffer() - ib.put(vector.indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - db.put(vector.values) - bytes - } - - private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { - case s: SparseVector => - serializeSparseVector(s) - case _ => - serializeDenseVector(vector.toArray) - } - - private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { - val packetLength = bytes.length - if (packetLength < 9) { - throw new IllegalArgumentException("Byte array too short.") - } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - if (magic != DENSE_MATRIX_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val rows = bb.getInt() - val cols = bb.getInt() - if (packetLength != 9 + 8 * rows * cols) { - throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") - } - val db = bb.asDoubleBuffer() - val ans = new Array[Array[Double]](rows.toInt) - for (i <- 0 until rows.toInt) { - ans(i) = new Array[Double](cols.toInt) - db.get(ans(i)) - } - ans - } - - private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { - val rows = doubles.length - var cols = 0 - if (rows > 0) { - cols = doubles(0).length - } - val bytes = new Array[Byte](9 + 8 * rows * cols) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_MATRIX_MAGIC) - bb.putInt(rows) - bb.putInt(cols) - val db = bb.asDoubleBuffer() - for (i <- 0 until rows) { - db.put(doubles(i)) - } - bytes - } - - private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { - val fb = serializeDoubleVector(p.features) - val bytes = new Array[Byte](1 + 8 + fb.length) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(LABELED_POINT_MAGIC) - bb.putDouble(p.label) - bb.put(fb) - bytes - } - - private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { - require(bytes.length >= 9, "Byte array too short") - val magic = bytes(0) - if (magic != LABELED_POINT_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val labelBytes = ByteBuffer.wrap(bytes, 1, 8) - labelBytes.order(ByteOrder.nativeOrder()) - val label = labelBytes.asDoubleBuffer().get(0) - LabeledPoint(label, deserializeDoubleVector(bytes, 9)) - } /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. @@ -236,17 +61,17 @@ class PythonMLLibAPI extends Serializable { jsc: JavaSparkContext, path: String, minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint) + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val initialWeights = deserializeDoubleVector(initialWeightsBA) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) + val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights)) + ret.add(SerDe.serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) ret } @@ -405,12 +230,12 @@ class PythonMLLibAPI extends Serializable { def trainNaiveBayes( dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(Vectors.dense(model.labels))) - ret.add(serializeDoubleVector(Vectors.dense(model.pi))) - ret.add(serializeDoubleMatrix(model.theta)) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels))) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi))) + ret.add(SerDe.serializeDoubleMatrix(model.theta)) ret } @@ -423,52 +248,13 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes)) + val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes)) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) + ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) ret } - /** Unpack a Rating object from an array of bytes */ - private def unpackRating(ratingBytes: Array[Byte]): Rating = { - val bb = ByteBuffer.wrap(ratingBytes) - bb.order(ByteOrder.nativeOrder()) - val user = bb.getInt() - val product = bb.getInt() - val rating = bb.getDouble() - new Rating(user, product, rating) - } - - /** Unpack a tuple of Ints from an array of bytes */ - private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { - val bb = ByteBuffer.wrap(tupleBytes) - bb.order(ByteOrder.nativeOrder()) - val v1 = bb.getInt() - val v2 = bb.getInt() - (v1, v2) - } - - /** - * Serialize a Rating object into an array of bytes. - * It can be deserialized using RatingDeserializer(). - * - * @param rate the Rating object to serialize - * @return - */ - private[spark] def serializeRating(rate: Rating): Array[Byte] = { - val len = 3 - val bytes = new Array[Byte](4 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(rate.user.toDouble) - db.put(rate.product.toDouble) - db.put(rate.rating) - bytes - } - /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -481,7 +267,7 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.train(ratings, rank, iterations, lambda, blocks) } @@ -498,7 +284,7 @@ class PythonMLLibAPI extends Serializable { lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } @@ -519,7 +305,7 @@ class PythonMLLibAPI extends Serializable { maxDepth: Int, maxBins: Int): DecisionTreeModel = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) val algo = Algo.fromString(algoStr) val impurity = Impurities.fromString(impurityStr) @@ -545,7 +331,7 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, featuresBytes: Array[Byte]): Double = { - val features: Vector = deserializeDoubleVector(featuresBytes) + val features: Vector = SerDe.deserializeDoubleVector(featuresBytes) model.predict(features) } @@ -559,8 +345,17 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) - model.predict(data).map(serializeDouble) + val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes)) + model.predict(data).map(SerDe.serializeDouble) + } + + /** + * Java stub for mllib Statistics.colStats(X: RDD[Vector]). + * TODO figure out return type. + */ + def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = { + val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_))) + new MultivariateStatisticalSummarySerialized(cStats) } /** @@ -569,17 +364,17 @@ class PythonMLLibAPI extends Serializable { * pyspark. */ def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { - val inputMatrix = X.rdd.map(deserializeDoubleVector(_)) + val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_)) val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) - serializeDoubleMatrix(to2dArray(result)) + SerDe.serializeDoubleMatrix(SerDe.to2dArray(result)) } /** * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). */ def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { - val xDeser = x.rdd.map(deserializeDouble(_)) - val yDeser = y.rdd.map(deserializeDouble(_)) + val xDeser = x.rdd.map(SerDe.deserializeDouble(_)) + val yDeser = y.rdd.map(SerDe.deserializeDouble(_)) Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) } @@ -588,12 +383,6 @@ class PythonMLLibAPI extends Serializable { if (method == null) CorrelationNames.defaultCorrName else method } - // Reformat a Matrix into Array[Array[Double]] for serialization - private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { - val values = matrix.toArray - Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) - } - // Used by the *RDD methods to get default seed if not passed in from pyspark private def getSeedOrDefault(seed: java.lang.Long): Long = { if (seed == null) Utils.random.nextLong else seed @@ -621,7 +410,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -633,7 +422,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -646,7 +435,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble) } /** @@ -659,7 +448,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -672,7 +461,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -686,7 +475,256 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) + } + +} + +/** + * :: DeveloperApi :: + * MultivariateStatisticalSummary with Vector fields serialized. + */ +@DeveloperApi +class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary) + extends Serializable { + + def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean) + + def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance) + + def count: Long = summary.count + + def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros) + + def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max) + + def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min) +} + +/** + * SerDe utility functions for PythonMLLibAPI. + */ +private[spark] object SerDe extends Serializable { + private val DENSE_VECTOR_MAGIC: Byte = 1 + private val SPARSE_VECTOR_MAGIC: Byte = 2 + private val DENSE_MATRIX_MAGIC: Byte = 3 + private val LABELED_POINT_MAGIC: Byte = 4 + + private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { + require(bytes.length - offset >= 5, "Byte array too short") + val magic = bytes(offset) + if (magic == DENSE_VECTOR_MAGIC) { + deserializeDenseVector(bytes, offset) + } else if (magic == SPARSE_VECTOR_MAGIC) { + deserializeSparseVector(bytes, offset) + } else { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } } + private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { + require(bytes.length - offset == 8, "Wrong size byte array for Double") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + bb.getDouble + } + + private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 5, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val length = bb.getInt() + require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](length.toInt) + db.get(ans) + Vectors.dense(ans) + } + + private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 9, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val size = bb.getInt() + val nonZeros = bb.getInt() + require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) + val ib = bb.asIntBuffer() + val indices = new Array[Int](nonZeros) + ib.get(indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + val values = new Array[Double](nonZeros) + db.get(values) + Vectors.sparse(size, indices, values) + } + + /** + * Returns an 8-byte array for the input Double. + * + * Note: we currently do not use a magic byte for double for storage efficiency. + * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. + * The corresponding deserializer, deserializeDouble, needs to be modified as well if the + * serialization scheme changes. + */ + private[python] def serializeDouble(double: Double): Array[Byte] = { + val bytes = new Array[Byte](8) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putDouble(double) + bytes + } + + private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { + val len = doubles.length + val bytes = new Array[Byte](5 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_VECTOR_MAGIC) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(doubles) + bytes + } + + private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = { + val nonZeros = vector.indices.length + val bytes = new Array[Byte](9 + 12 * nonZeros) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(SPARSE_VECTOR_MAGIC) + bb.putInt(vector.size) + bb.putInt(nonZeros) + val ib = bb.asIntBuffer() + ib.put(vector.indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + db.put(vector.values) + bytes + } + + private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { + case s: SparseVector => + serializeSparseVector(s) + case _ => + serializeDenseVector(vector.toArray) + } + + private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + val packetLength = bytes.length + if (packetLength < 9) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + if (magic != DENSE_MATRIX_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val rows = bb.getInt() + val cols = bb.getInt() + if (packetLength != 9 + 8 * rows * cols) { + throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Array[Double]](rows.toInt) + for (i <- 0 until rows.toInt) { + ans(i) = new Array[Double](cols.toInt) + db.get(ans(i)) + } + ans + } + + private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + val rows = doubles.length + var cols = 0 + if (rows > 0) { + cols = doubles(0).length + } + val bytes = new Array[Byte](9 + 8 * rows * cols) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_MATRIX_MAGIC) + bb.putInt(rows) + bb.putInt(cols) + val db = bb.asDoubleBuffer() + for (i <- 0 until rows) { + db.put(doubles(i)) + } + bytes + } + + private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { + val fb = serializeDoubleVector(p.features) + val bytes = new Array[Byte](1 + 8 + fb.length) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(LABELED_POINT_MAGIC) + bb.putDouble(p.label) + bb.put(fb) + bytes + } + + private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { + require(bytes.length >= 9, "Byte array too short") + val magic = bytes(0) + if (magic != LABELED_POINT_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val labelBytes = ByteBuffer.wrap(bytes, 1, 8) + labelBytes.order(ByteOrder.nativeOrder()) + val label = labelBytes.asDoubleBuffer().get(0) + LabeledPoint(label, deserializeDoubleVector(bytes, 9)) + } + + // Reformat a Matrix into Array[Array[Double]] for serialization + private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { + val values = matrix.toArray + Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + } + + + /** Unpack a Rating object from an array of bytes */ + private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + new Rating(user, product, rating) + } + + /** Unpack a tuple of Ints from an array of bytes */ + def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { + val bb = ByteBuffer.wrap(tupleBytes) + bb.order(ByteOrder.nativeOrder()) + val v1 = bb.getInt() + val v2 = bb.getInt() + (v1, v2) + } + + /** + * Serialize a Rating object into an array of bytes. + * It can be deserialized using RatingDeserializer(). + * + * @param rate the Rating object to serialize + * @return + */ + def serializeRating(rate: Rating): Array[Byte] = { + val len = 3 + val bytes = new Array[Byte](4 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(rate.user.toDouble) + db.put(rate.product.toDouble) + db.put(rate.rating) + bytes + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index a1a76fcbe9f9c..478c6485052b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.api.python.PythonMLLibAPI +import org.apache.spark.mllib.api.python.SerDe /** * Model representing the result of matrix factorization. @@ -117,9 +117,8 @@ class MatrixFactorizationModel private[mllib] ( */ @DeveloperApi def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val pythonAPI = new PythonMLLibAPI() - val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes)) - predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) + val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes)) + predict(usersProducts).map(rate => SerDe.serializeRating(rate)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index cf8679610e191..3cf1028fbc725 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations @@ -30,6 +31,18 @@ import org.apache.spark.rdd.RDD @Experimental object Statistics { + /** + * :: Experimental :: + * Computes column-wise summary statistics for the input RDD[Vector]. + * + * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. + * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + */ + @Experimental + def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { + new RowMatrix(X).computeColumnSummaryStatistics() + } + /** * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index bd413a80f5107..092d67bbc5238 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends FunSuite { - val py = new PythonMLLibAPI test("vector serialization") { val vectors = Seq( @@ -34,8 +33,8 @@ class PythonMLLibAPISuite extends FunSuite { Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), Vectors.sparse(2, Array(1), Array(-2.0))) vectors.foreach { v => - val bytes = py.serializeDoubleVector(v) - val u = py.deserializeDoubleVector(bytes) + val bytes = SerDe.serializeDoubleVector(v) + val u = SerDe.deserializeDoubleVector(bytes) assert(u.getClass === v.getClass) assert(u === v) } @@ -50,8 +49,8 @@ class PythonMLLibAPISuite extends FunSuite { LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) points.foreach { p => - val bytes = py.serializeLabeledPoint(p) - val q = py.deserializeLabeledPoint(bytes) + val bytes = SerDe.serializeLabeledPoint(p) + val q = SerDe.deserializeLabeledPoint(bytes) assert(q.label === p.label) assert(q.features.getClass === p.features.getClass) assert(q.features === p.features) @@ -60,8 +59,8 @@ class PythonMLLibAPISuite extends FunSuite { test("double serialization") { for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { - val bytes = py.serializeDouble(x) - val deser = py.deserializeDouble(bytes) + val bytes = SerDe.serializeDouble(x) + val deser = SerDe.deserializeDouble(bytes) // We use `equals` here for comparison because we cannot use `==` for NaN assert(x.equals(deser)) } @@ -70,14 +69,14 @@ class PythonMLLibAPISuite extends FunSuite { test("matrix to 2D array") { val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) val matrix = Matrices.dense(2, 3, values) - val arr = py.to2dArray(matrix) + val arr = SerDe.to2dArray(matrix) val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) assert(arr === expected) // Test conversion for empty matrix val empty = Array[Double]() val emptyMatrix = Matrices.dense(0, 0, empty) - val empty2D = py.to2dArray(emptyMatrix) + val empty2D = SerDe.to2dArray(emptyMatrix) assert(empty2D === Array[Array[Double]]()) } } diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 982906b9d09f0..a73abc5ff90df 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,11 +22,75 @@ from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix + _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +class MultivariateStatisticalSummary(object): + + """ + Trait for multivariate statistical summary of a data matrix. + """ + + def __init__(self, sc, java_summary): + """ + :param sc: Spark context + :param java_summary: Handle to Java summary object + """ + self._sc = sc + self._java_summary = java_summary + + def __del__(self): + self._sc._gateway.detach(self._java_summary) + + def mean(self): + return _deserialize_double_vector(self._java_summary.mean()) + + def variance(self): + return _deserialize_double_vector(self._java_summary.variance()) + + def count(self): + return self._java_summary.count() + + def numNonzeros(self): + return _deserialize_double_vector(self._java_summary.numNonzeros()) + + def max(self): + return _deserialize_double_vector(self._java_summary.max()) + + def min(self): + return _deserialize_double_vector(self._java_summary.min()) class Statistics(object): + @staticmethod + def colStats(X): + """ + Computes column-wise summary statistics for the input RDD[Vector]. + + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), + ... Vectors.dense([4, 5, 0, 3]), + ... Vectors.dense([6, 7, 0, 8])]) + >>> cStats = Statistics.colStats(rdd) + >>> cStats.mean() + array([ 4., 4., 0., 3.]) + >>> cStats.variance() + array([ 4., 13., 0., 25.]) + >>> cStats.count() + 3L + >>> cStats.numNonzeros() + array([ 3., 2., 0., 3.]) + >>> cStats.max() + array([ 6., 7., 0., 8.]) + >>> cStats.min() + array([ 2., 0., 0., -2.]) + """ + sc = X.ctx + Xser = _get_unmangled_double_vector_rdd(X) + cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd) + return MultivariateStatisticalSummary(sc, cStats) + @staticmethod def corr(x, y=None, method=None): """ From 78f2f99f1a36c2d01ccf7a709bf19b1a1a0f53fc Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 13 Aug 2014 14:42:57 -0700 Subject: [PATCH 116/231] [SPARK-2963] [SQL] There no documentation about building to use HiveServer and CLI for SparkSQL Author: Kousuke Saruta Closes #1885 from sarutak/SPARK-2963 and squashes the following commits: ed53329 [Kousuke Saruta] Modified description and notaton of proper noun 07c59fc [Kousuke Saruta] Added a description about how to build to use HiveServer and CLI for SparkSQL to building-with-maven.md 6e6645a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2963 c88fa93 [Kousuke Saruta] Added a description about building to use HiveServer and CLI for SparkSQL (cherry picked from commit 869f06c759c29b09c8dc72e0e4034c03f908ba30) Signed-off-by: Michael Armbrust --- README.md | 9 +++++++++ docs/building-with-maven.md | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/README.md b/README.md index f87e07aa5cc90..a1a48f5bd0819 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,15 @@ If your project is built with Maven, add this to your POM file's ` +## A Note About Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. + + $ sbt/sbt -Phive-thriftserver assembly + + ## Configuration Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 672d0ef114f6d..4d87ab92cec5b 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -96,6 +96,15 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} +# Building Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. +{% highlight bash %} +mvn -Phive-thriftserver assembly +{% endhighlight %} + # Spark Tests in Maven Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). From 99360208792cb68aca6d26258be6c679c58f1cc8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Aug 2014 14:56:11 -0700 Subject: [PATCH 117/231] [SPARK-3013] [SQL] [PySpark] convert array into list because Pyrolite does not support array from Python 2.6 Author: Davies Liu Closes #1928 from davies/fix_array and squashes the following commits: 858e6c5 [Davies Liu] convert array into list (cherry picked from commit c974a716e17c9fe2628b1ba1d4309ead1bd855ad) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 27f1d2ddf942a..46540ca3f1e8a 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -498,10 +498,7 @@ def _infer_schema(row): def _create_converter(obj, dataType): """Create an converter to drop the names of fields in obj """ - if not _has_struct(dataType): - return lambda x: x - - elif isinstance(dataType, ArrayType): + if isinstance(dataType, ArrayType): conv = _create_converter(obj[0], dataType.elementType) return lambda row: map(conv, row) @@ -510,6 +507,9 @@ def _create_converter(obj, dataType): conv = _create_converter(value, dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif not isinstance(dataType, StructType): + return lambda x: x + # dataType must be StructType names = [f.name for f in dataType.fields] @@ -529,8 +529,7 @@ def _create_converter(obj, dataType): elif hasattr(obj, "__dict__"): # object conv = lambda o: [o.__dict__.get(n, None) for n in names] - nested = any(_has_struct(f.dataType) for f in dataType.fields) - if not nested: + if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): return conv row = conv(obj) @@ -1037,7 +1036,8 @@ def inferSchema(self, rdd): raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated") + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.Row instead") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) From a7bc21ca7a36f6f0d9004c742bbcd23367e1ecc3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Aug 2014 14:57:12 -0700 Subject: [PATCH 118/231] [SPARK-2983] [PySpark] improve performance of sortByKey() 1. skip partitionBy() when numOfPartition is 1 2. use bisect_left (O(lg(N))) instread of loop (O(N)) in rangePartitioner Author: Davies Liu Closes #1898 from davies/sort and squashes the following commits: 0a9608b [Davies Liu] Merge branch 'master' into sort 1cf9565 [Davies Liu] improve performance of sortByKey() (cherry picked from commit 434bea1c002b597cff9db899da101490e1f1e9ed) Signed-off-by: Matei Zaharia --- python/pyspark/rdd.py | 47 ++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 756e8f35fb03d..3934bdda0a466 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -30,6 +30,7 @@ from threading import Thread import warnings import heapq +import bisect from random import Random from math import sqrt, log @@ -574,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortByKey(True, 1).collect() + [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)] @@ -584,42 +587,40 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - bounds = list() + if numPartitions == 1: + if self.getNumPartitions() > 1: + self = self.coalesce(1) + + def sort(iterator): + return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) + + return self.mapPartitions(sort) # first compute the boundary of each part via sampling: we want to partition # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them - if numPartitions > 1: - rddSize = self.count() - # constant from Spark's RangePartitioner - maxSampleSize = numPartitions * 20.0 - fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - - samples = self.sample(False, fraction, 1).map( - lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) - - # we have numPartitions many parts but one of the them has - # an implicit boundary - for i in range(0, numPartitions - 1): - index = (len(samples) - 1) * (i + 1) / numPartitions - bounds.append(samples[index]) + rddSize = self.count() + maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner + fraction = min(maxSampleSize / max(rddSize, 1), 1.0) + samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() + samples = sorted(samples, reverse=(not ascending), key=keyfunc) + + # we have numPartitions many parts but one of the them has + # an implicit boundary + bounds = [samples[len(samples) * (i + 1) / numPartitions] + for i in range(0, numPartitions - 1)] def rangePartitionFunc(k): - p = 0 - while p < len(bounds) and keyfunc(k) > bounds[p]: - p += 1 + p = bisect.bisect_left(bounds, keyfunc(k)) if ascending: return p else: return numPartitions - 1 - p def mapFunc(iterator): - yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) + return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc) - .mapPartitions(mapFunc, preservesPartitioning=True) - .flatMap(lambda x: x, preservesPartitioning=True)) + return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True) def sortBy(self, keyfunc, ascending=True, numPartitions=None): """ From e63bf87099b95f261ed09cf90d20e564f0500798 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 13 Aug 2014 16:20:49 -0700 Subject: [PATCH 119/231] [MLLIB] use Iterator.fill instead of Array.fill Iterator.fill uses less memory Author: Xiangrui Meng Closes #1930 from mengxr/rand-gen-iter and squashes the following commits: 24178ca [Xiangrui Meng] use Iterator.fill instead of Array.fill (cherry picked from commit 7ecb867c4cd6916b6cb12f2ece1a4c88591ad5b5) Signed-off-by: Xiangrui Meng --- .../scala/org/apache/spark/mllib/rdd/RandomRDD.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index c8db3910c6eab..910eff9540a47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -105,16 +105,16 @@ private[mllib] object RandomRDD { def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(generator.nextValue()).toIterator + Iterator.fill(partition.size)(generator.nextValue()) } // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getVectorIterator(partition: RandomRDDPartition[Double], - vectorSize: Int): Iterator[Vector] = { + def getVectorIterator( + partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(new DenseVector( - (0 until vectorSize).map { _ => generator.nextValue() }.toArray)).toIterator + Iterator.fill(partition.size)(new DenseVector(Array.fill(vectorSize)(generator.nextValue()))) } } From 8732375e65b7191fb0e44fd91f200cae99d840ec Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 13 Aug 2014 16:27:50 -0700 Subject: [PATCH 120/231] [SPARK-3004][SQL] Added null checking when retrieving row set JIRA issue: [SPARK-3004](https://issues.apache.org/jira/browse/SPARK-3004) HiveThriftServer2 throws exception when the result set contains `NULL`. Should check `isNullAt` in `SparkSQLOperationManager.getNextRowSet`. Note that simply using `row.addColumnValue(null)` doesn't work, since Hive set the column type of a null `ColumnValue` to String by default. Author: Cheng Lian Closes #1920 from liancheng/spark-3004 and squashes the following commits: 1b1db1c [Cheng Lian] Adding NULL column values in the Hive way 2217722 [Cheng Lian] Fixed SPARK-3004: added null checking when retrieving row set (cherry picked from commit bdc7a1a4749301f8d18617c130c7766684aa8789) Signed-off-by: Michael Armbrust --- .../server/SparkSQLOperationManager.scala | 93 +++++++++++++------ .../data/files/small_kv_with_null.txt | 10 ++ .../thriftserver/HiveThriftServer2Suite.scala | 26 +++++- 3 files changed, 96 insertions(+), 33 deletions(-) create mode 100644 sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index f192f490ac3d0..9338e8121b0fe 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -73,35 +73,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage var curCol = 0 while (curCol < sparkRow.length) { - dataTypes(curCol) match { - case StringType => - row.addString(sparkRow(curCol).asInstanceOf[String]) - case IntegerType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) - case BooleanType => - row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) - case DoubleType => - row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) - case FloatType => - row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) - case DecimalType => - val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal - row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) - case ByteType => - row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) - case ShortType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) - case TimestampType => - row.addColumnValue( - ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = result - .queryExecution - .asInstanceOf[HiveContext#QueryExecution] - .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) - row.addColumnValue(ColumnValue.stringValue(hiveString)) + if (sparkRow.isNullAt(curCol)) { + addNullColumnValue(sparkRow, row, curCol) + } else { + addNonNullColumnValue(sparkRow, row, curCol) } curCol += 1 } @@ -112,6 +87,66 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } } + def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(from(ordinal).asInstanceOf[String]) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) + case DecimalType => + val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) + case ShortType => + to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal))) + case TimestampType => + to.addColumnValue( + ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((from.get(ordinal), dataTypes(ordinal))) + to.addColumnValue(ColumnValue.stringValue(hiveString)) + } + } + + def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(null) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(null)) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(null)) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(null)) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(null)) + case DecimalType => + to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) + case LongType => + to.addColumnValue(ColumnValue.longValue(null)) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(null)) + case ShortType => + to.addColumnValue(ColumnValue.intValue(null)) + case TimestampType => + to.addColumnValue(ColumnValue.timestampValue(null)) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + to.addColumnValue(ColumnValue.stringValue(null: String)) + } + } + def getResultSetSchema: TableSchema = { logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt new file mode 100644 index 0000000000000..ae08c640e6c13 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt @@ -0,0 +1,10 @@ +238val_238 + +311val_311 +val_27 +val_165 +val_409 +255val_255 +278val_278 +98val_98 +val_484 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 78bffa2607349..aedef6ce1f5f2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -113,22 +113,40 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt val stmt = createStatement() stmt.execute("DROP TABLE IF EXISTS test") stmt.execute("DROP TABLE IF EXISTS test_cached") - stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute("CREATE TABLE test(key INT, val STRING)") stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") - stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4") stmt.execute("CACHE TABLE test_cached") - var rs = stmt.executeQuery("select count(*) from test") + var rs = stmt.executeQuery("SELECT COUNT(*) FROM test") rs.next() assert(rs.getInt(1) === 5) - rs = stmt.executeQuery("select count(*) from test_cached") + rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached") rs.next() assert(rs.getInt(1) === 4) stmt.close() } + test("SPARK-3004 regression: result set containing NULL") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv_with_null.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test_null") + stmt.execute("CREATE TABLE test_null(key INT, val STRING)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") + + val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + var count = 0 + while (rs.next()) { + count += 1 + } + assert(count === 5) + + stmt.close() + } + def getConnection: Connection = { val connectURI = s"jdbc:hive2://localhost:$PORT/" DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") From 0fb1198fb9a0318b927857243eda972d336d2d8d Mon Sep 17 00:00:00 2001 From: tianyi Date: Wed, 13 Aug 2014 16:50:02 -0700 Subject: [PATCH 121/231] [SPARK-2817] [SQL] add "show create table" support In spark sql component, the "show create table" syntax had been disabled. We thought it is a useful funciton to describe a hive table. Author: tianyi Author: tianyi Author: tianyi Closes #1760 from tianyi/spark-2817 and squashes the following commits: 7d28b15 [tianyi] [SPARK-2817] fix too short prefix problem cbffe8b [tianyi] [SPARK-2817] fix the case problem 565ec14 [tianyi] [SPARK-2817] fix the case problem 60d48a9 [tianyi] [SPARK-2817] use system temporary folder instead of temporary files in the source tree, and also clean some empty line dbe1031 [tianyi] [SPARK-2817] move some code out of function rewritePaths, as it may be called multiple times 9b2ba11 [tianyi] [SPARK-2817] fix the line length problem 9f97586 [tianyi] [SPARK-2817] remove test.tmp.dir from pom.xml bfc2999 [tianyi] [SPARK-2817] add "File.separator" support, create a "testTmpDir" outside the rewritePaths bde800a [tianyi] [SPARK-2817] add "${system:test.tmp.dir}" support add "last_modified_by" to nonDeterministicLineIndicators in HiveComparisonTest bb82726 [tianyi] [SPARK-2817] remove test which requires a system from the whitelist. bbf6b42 [tianyi] [SPARK-2817] add a systemProperties named "test.tmp.dir" to pass the test which contains "${system:test.tmp.dir}" a337bd6 [tianyi] [SPARK-2817] add "show create table" support a03db77 [tianyi] [SPARK-2817] add "show create table" support (cherry picked from commit 13f54e2b97744beab45e1bdbcdf8d215ca481b78) Signed-off-by: Michael Armbrust --- .../execution/HiveCompatibilitySuite.scala | 8 +++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 1 + .../org/apache/spark/sql/hive/TestHive.scala | 8 +++++++ ...e_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 | 0 ...e_alter-1-2a91d52719cf4552ebeb867204552a26 | 18 +++++++++++++++ ..._alter-10-259d978ed9543204c8b9c25b6e25b0de | 0 ...e_alter-2-928cc85c025440b731e5ee33e437e404 | 0 ...e_alter-3-2a91d52719cf4552ebeb867204552a26 | 22 +++++++++++++++++++ ...e_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 | 0 ...e_alter-5-2a91d52719cf4552ebeb867204552a26 | 21 ++++++++++++++++++ ...le_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb | 0 ...e_alter-7-2a91d52719cf4552ebeb867204552a26 | 21 ++++++++++++++++++ ...e_alter-8-22ab6ed5b15a018756f454dd2294847e | 0 ...e_alter-9-2a91d52719cf4552ebeb867204552a26 | 21 ++++++++++++++++++ ...b_table-0-67509558a4b2d39b25787cca33f52635 | 0 ...b_table-1-549981e00a3d95f03dd5a9ef6044aa20 | 2 ++ ...db_table-2-34ae7e611d0aedbc62b6e420347abee | 0 ...b_table-3-7a9e67189d3d4151f23b12c22bde06b5 | 0 ...b_table-4-b585371b624cbab2616a49f553a870a0 | 13 +++++++++++ ...b_table-5-964757b7e7f2a69fe36132c1a5712199 | 0 ...b_table-6-ac09cf81e7e734cf10406f30b9fa566e | 0 ...limited-0-97228478b9925f06726ceebb6571bf34 | 0 ...limited-1-2a91d52719cf4552ebeb867204552a26 | 17 ++++++++++++++ ...limited-2-259d978ed9543204c8b9c25b6e25b0de | 0 ...itioned-0-4be9a3b1ff0840786a1f001cba170a0c | 0 ...itioned-1-2a91d52719cf4552ebeb867204552a26 | 16 ++++++++++++++ ...itioned-2-259d978ed9543204c8b9c25b6e25b0de | 0 ...e_serde-0-33f15d91810b75ee05c7b9dea0abb01c | 0 ...e_serde-1-2a91d52719cf4552ebeb867204552a26 | 15 +++++++++++++ ...e_serde-2-259d978ed9543204c8b9c25b6e25b0de | 0 ...e_serde-3-fd12b3e0fe30f5d71c67676791b4a33b | 0 ...e_serde-4-2a91d52719cf4552ebeb867204552a26 | 14 ++++++++++++ ...e_serde-5-259d978ed9543204c8b9c25b6e25b0de | 0 ...le_view-0-ecef6821e4e9212e553ca38142fd0250 | 0 ...le_view-1-1e931ea3fa6065107859ffbb29bb0ed7 | 1 + ...le_view-2-ed97e9e56d95c5b3db57485cba5ad17f | 0 .../hive/execution/HiveComparisonTest.scala | 1 + 37 files changed, 199 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e create mode 100644 sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c create mode 100644 sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4fef071161719..210753efe7678 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -635,6 +635,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "serde_regex", "serde_reported_schema", "set_variable_sub", + "show_create_table_partitioned", + "show_create_table_delimited", + "show_create_table_alter", + "show_create_table_view", + "show_create_table_serde", + "show_create_table_db_table", + "show_create_table_does_not_exist", + "show_create_table_index", "show_describe_func_quotes", "show_functions", "show_partitions", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 05b2f5f6cd3f7..1d9ba1b24a7a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -53,6 +53,7 @@ private[hive] object HiveQl { protected val nativeCommands = Seq( "TOK_DESCFUNCTION", "TOK_DESCDATABASE", + "TOK_SHOW_CREATETABLE", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index d890df866fbe5..a013f3f7a805f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -70,6 +70,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { setConf("hive.metastore.warehouse.dir", warehousePath) } + val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") + testTempDir.delete() + testTempDir.mkdir() + + // For some hive test case which contain ${system:test.tmp.dir} + System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) + configure() // Must be called before initializing the catalog below. /** The location of the compiled hive distribution */ @@ -109,6 +116,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveFilesTemp.mkdir() hiveFilesTemp.deleteOnExit() + val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) } else { diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 b/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..3c1fc128bedce --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,18 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 b/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2ece813dd7d56 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,22 @@ +CREATE TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'temporary table' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'EXTERNAL'='FALSE', + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 b/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2af657bd29506 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb b/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..f793ffb7a0bfd --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e b/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..c65aff26a7fc1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='1') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 b/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 new file mode 100644 index 0000000000000..707b2ae3ed1df --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 @@ -0,0 +1,2 @@ +default +tmp_feng diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee b/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 new file mode 100644 index 0000000000000..b5a18368ed85e --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -0,0 +1,13 @@ +CREATE TABLE tmp_feng.tmp_showcrt( + key string, + value int) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_feng.db/tmp_showcrt' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132107') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 b/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e b/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..d36ad25dc8273 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,17 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +ROW FORMAT DELIMITED + FIELDS TERMINATED BY ',' + COLLECTION ITEMS TERMINATED BY '|' + MAP KEYS TERMINATED BY '%' + LINES TERMINATED BY '\n' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132730') diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c b/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..9e572c0d7df6a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,16 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + newvalue boolean COMMENT 'a new value') +COMMENT 'temporary table' +PARTITIONED BY ( + value bigint COMMENT 'some value') +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132112') diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c b/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..69a38e1a7b20a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,15 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +COMMENT 'temporary table' +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b b/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..b4e693dc622fb --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,14 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + value boolean) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='$', + 'field.delim'=',') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 b/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 new file mode 100644 index 0000000000000..be3fb3ce30960 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 @@ -0,0 +1 @@ +CREATE VIEW tmp_copy_src AS SELECT `src`.`key`, `src`.`value` FROM `default`.`src` diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f b/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 0ebaf6ffd5458..502ce8fb297e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -161,6 +161,7 @@ abstract class HiveComparisonTest "transient_lastDdlTime", "grantTime", "lastUpdateTime", + "last_modified_by", "last_modified_time", "Owner:", // The following are hive specific schema parameters which we do not need to match exactly. From 71b84086c471b9eea934391c3f21399de83a0cdb Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 13 Aug 2014 17:35:38 -0700 Subject: [PATCH 122/231] [SPARK-2994][SQL] Support for udfs that take complex types Author: Michael Armbrust Closes #1915 from marmbrus/arrayUDF and squashes the following commits: a1c503d [Michael Armbrust] Support for udfs that take complex types (cherry picked from commit 9256d4a9c8c9ddb9ae6bbe3c3b99b03fb66b946b) Signed-off-by: Michael Armbrust --- .../spark/sql/hive/HiveInspectors.scala | 14 ++++++- .../org/apache/spark/sql/hive/hiveUdfs.scala | 41 +++++++++++-------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 354fcd53f303b..943bbaa8ce25e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -71,6 +71,9 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + + // Hive seems to return this for struct types? + case c: Class[_] if c == classOf[java.lang.Object] => NullType } /** Converts hive types to native catalyst types. */ @@ -147,7 +150,10 @@ private[hive] trait HiveInspectors { case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + // Some UDFs seem to assume we pass in a HashMap. + val hashMap = new java.util.HashMap[AnyRef, AnyRef]() + hashMap.putAll(m.map { case (k, v) => wrap(k) -> wrap(v) }) + hashMap case null => null } @@ -214,6 +220,12 @@ private[hive] trait HiveInspectors { import TypeInfoFactory._ def toTypeInfo: TypeInfo = dt match { + case ArrayType(elemType, _) => + getListTypeInfo(elemType.toTypeInfo) + case StructType(fields) => + getStructTypeInfo(fields.map(_.name), fields.map(_.dataType.toTypeInfo)) + case MapType(keyType, valueType, _) => + getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo case BooleanType => booleanTypeInfo case ByteType => byteTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 179aac5cbd5cd..c6497a15efa0c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -55,7 +55,10 @@ private[hive] abstract class HiveFunctionRegistry HiveSimpleUdf( functionClassName, - children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } + children.zip(expectedDataTypes).map { + case (e, NullType) => e + case (e, t) => Cast(e, t) + } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdf(functionClassName, children) @@ -115,22 +118,26 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) } - val constructor = matchingConstructor.getOrElse( - sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) - - (a: Any) => { - logDebug( - s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") - // We must make sure that primitives get boxed java style. - if (a == null) { - null - } else { - constructor.newInstance(a match { - case i: Int => i: java.lang.Integer - case bd: BigDecimal => new HiveDecimal(bd.underlying()) - case other: AnyRef => other - }).asInstanceOf[AnyRef] - } + matchingConstructor match { + case Some(constructor) => + (a: Any) => { + logDebug( + s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} $constructor.") + // We must make sure that primitives get boxed java style. + if (a == null) { + null + } else { + constructor.newInstance(a match { + case i: Int => i: java.lang.Integer + case bd: BigDecimal => new HiveDecimal(bd.underlying()) + case other: AnyRef => other + }).asInstanceOf[AnyRef] + } + } + case None => + (a: Any) => a match { + case wrapper => wrap(wrapper) + } } } From ee7d2cc1a935da62de968799c0ecc6f98e43361a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 13 Aug 2014 17:37:55 -0700 Subject: [PATCH 123/231] [SPARK-2650][SQL] More precise initial buffer size estimation for in-memory column buffer This is a follow up of #1880. Since the row number within a single batch is known, we can estimate a much more precise initial buffer size when building an in-memory column buffer. Author: Cheng Lian Closes #1901 from liancheng/precise-init-buffer-size and squashes the following commits: d5501fa [Cheng Lian] More precise initial buffer size estimation for in-memory column buffer (cherry picked from commit 376a82e196e102ef49b9722e8be0b01ac5890a8b) Signed-off-by: Michael Armbrust --- .../sql/columnar/InMemoryColumnarTableScan.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 3364d0e18bcc9..e63b4903041f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{SparkPlan, LeafNode} -import org.apache.spark.sql.Row -import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} object InMemoryRelation { def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = @@ -48,7 +47,9 @@ private[sql] case class InMemoryRelation( new Iterator[Array[ByteBuffer]] { def next() = { val columnBuilders = output.map { attribute => - ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) + val columnType = ColumnType(attribute.dataType) + val initialBufferSize = columnType.defaultSize * batchSize + ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) }.toArray var row: Row = null From e8e7f17e1e6d84268421dbfa315850b07a8a4c15 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 13 Aug 2014 17:40:59 -0700 Subject: [PATCH 124/231] [SPARK-2935][SQL]Fix parquet predicate push down bug Author: Michael Armbrust Closes #1863 from marmbrus/parquetPredicates and squashes the following commits: 10ad202 [Michael Armbrust] left <=> right f249158 [Michael Armbrust] quiet parquet tests. 802da5b [Michael Armbrust] Add test case. eab2eda [Michael Armbrust] Fix parquet predicate push down bug (cherry picked from commit 9fde1ff5fc114b5edb755ed40944607419b62184) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/parquet/ParquetFilters.scala | 5 +++-- sql/core/src/test/resources/log4j.properties | 3 +++ .../org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 5 ++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index cc575bedd8fcb..2298a9b933df5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -201,8 +201,9 @@ object ParquetFilters { (leftFilter, rightFilter) match { case (None, Some(filter)) => Some(filter) case (Some(filter), None) => Some(filter) - case (_, _) => - Some(new AndFilter(leftFilter.get, rightFilter.get)) + case (Some(leftF), Some(rightF)) => + Some(new AndFilter(leftF, rightF)) + case _ => None } } case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index dffd15a61838b..c7e0ff1cf6494 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,6 +36,9 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. +log4j.additivity.parquet.hadoop.ParquetRecordReader=false +log4j.logger.parquet.hadoop.ParquetRecordReader=OFF + log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9933575038bd3..502f6702e394e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -381,11 +381,14 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val predicate5 = new GreaterThan(attribute1, attribute2) val badfilter = ParquetFilters.createFilter(predicate5) assert(badfilter.isDefined === false) + + val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) + val badfilter2 = ParquetFilters.createFilter(predicate6) + assert(badfilter2.isDefined === false) } test("test filter by predicate pushdown") { for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { - println(s"testing field $myval") val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") assert( query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], From b5b632c8cd02fd1e65ebd22216d20ec76715fc5d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 13 Aug 2014 17:42:38 -0700 Subject: [PATCH 125/231] [SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled Author: Kousuke Saruta Closes #1891 from sarutak/SPARK-2970 and squashes the following commits: 4a2d2fe [Kousuke Saruta] Modified comment style 8bd833c [Kousuke Saruta] Modified style 6c0997c [Kousuke Saruta] Modified the timing of shutdown hook execution. It should be executed before shutdown hook of o.a.h.f.FileSystem (cherry picked from commit 905dc4b405e679feb145f5e6b35e952db2442e0d) Signed-off-by: Michael Armbrust --- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 4d0c506c5a397..4ed0f58ebc531 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -26,6 +26,8 @@ import jline.{ConsoleReader, History} import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} @@ -116,13 +118,17 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - Runtime.getRuntime.addShutdownHook( + /** + * This should be executed before shutdown hook of + * FileSystem to avoid race condition of FileSystem operation + */ + ShutdownHookManager.get.addShutdownHook( new Thread() { override def run() { SparkSQLEnv.stop() } } - ) + , FileSystem.SHUTDOWN_HOOK_PRIORITY - 1) // "-h" option has been passed, so connect to Hive thrift server. if (sessionState.getHost != null) { From a8d2649719b3d8fdb1eed29ef179a6a896b3e37a Mon Sep 17 00:00:00 2001 From: guowei Date: Wed, 13 Aug 2014 17:45:24 -0700 Subject: [PATCH 126/231] [SPARK-2986] [SQL] fixed: setting properties does not effect it seems that set command does not run by SparkSQLDriver. it runs on hive api. user can not change reduce number by setting spark.sql.shuffle.partitions but i think setting hive properties seems just a role to spark sql. Author: guowei Closes #1904 from guowei2/temp-branch and squashes the following commits: 7d47dde [guowei] fixed: setting properties like spark.sql.shuffle.partitions does not effective (cherry picked from commit 63d6777737ca8559d4344d1661500b8ad868bb47) Signed-off-by: Michael Armbrust --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 4ed0f58ebc531..c16a7d3661c66 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket @@ -284,7 +284,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { val driver = new SparkSQLDriver driver.init() From c6cb55a784ba8f9e5c4e7aadcc3ec9dce24f49ee Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 13 Aug 2014 18:08:38 -0700 Subject: [PATCH 127/231] SPARK-3020: Print completed indices rather than tasks in web UI Author: Patrick Wendell Closes #1933 from pwendell/speculation and squashes the following commits: 33a3473 [Patrick Wendell] Use OpenHashSet 8ce2ff0 [Patrick Wendell] SPARK-3020: Print completed indices rather than tasks in web UI (cherry picked from commit 0c7b452904fe6b5a966a66b956369123d8a9dd4b) Signed-off-by: Reynold Xin --- .../scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 1 + core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index a57a354620163..a3e9566832d06 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -153,6 +153,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { case org.apache.spark.Success => + stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 (None, Option(taskEnd.taskMetrics)) case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 3dcfaf76e4aba..15998404ed612 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -168,7 +168,7 @@ private[ui] class StageTableBase( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 85db15472a00c..a336bf7e1ed02 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.util.collection.OpenHashSet import scala.collection.mutable.HashMap @@ -38,6 +39,7 @@ private[jobs] object UIData { class StageUIData { var numActiveTasks: Int = _ var numCompleteTasks: Int = _ + var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ var executorRunTime: Long = _ From dcd99c3e63f8a5154f904ae57e945e8caaade649 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Wed, 13 Aug 2014 22:17:07 -0700 Subject: [PATCH 128/231] [SPARK-3006] Failed to execute spark-shell in Windows OS Modified the order of the options and arguments in spark-shell.cmd Author: Masayoshi TSUZUKI Closes #1918 from tsudukim/feature/SPARK-3006 and squashes the following commits: 8bba494 [Masayoshi TSUZUKI] [SPARK-3006] Failed to execute spark-shell in Windows OS 1a32410 [Masayoshi TSUZUKI] [SPARK-3006] Failed to execute spark-shell in Windows OS (cherry picked from commit 9497b12d429cf9d075807896637e40e205175203) Signed-off-by: Andrew Or --- bin/spark-shell.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index b56d69801171c..2ee60b4e2a2b3 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell From bf7c6e198822d155c23cfaa7219c36e5db8d1eeb Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 Aug 2014 23:24:23 -0700 Subject: [PATCH 129/231] [Docs] Add missing tags (minor) These configs looked inconsistent from the rest. Author: Andrew Or Closes #1936 from andrewor14/docs-code and squashes the following commits: 15f578a [Andrew Or] Add tag (cherry picked from commit e4245656438d00714ebd59e89c4de3fdaae83494) Signed-off-by: Reynold Xin --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8136bd62ab6af..c8336b39133de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -562,7 +562,7 @@ Apart from these, the following properties are also available, and may be useful - + - + + + + + +
    spark.io.compression.codecorg.apache.spark.io.
    SnappyCompressionCodec
    snappy - The codec used to compress internal data such as RDD partitions and shuffle outputs. - By default, Spark provides three codecs: org.apache.spark.io.LZ4CompressionCodec, + The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, + Spark provides three codecs: lz4, lzf, and snappy. You + can also use fully qualified class names to specify the codec, e.g. + org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, and org.apache.spark.io.SnappyCompressionCodec. {submissionTime} {formattedDuration} - {makeProgressBar(stageData.numActiveTasks, stageData.numCompleteTasks, + {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} {inputReadWithUnit}
    spark.hadoop.validateOutputSpecsspark.hadoop.validateOutputSpecs true If set to true, validates the output specification (e.g. checking if the output directory already exists) used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing @@ -570,7 +570,7 @@ Apart from these, the following properties are also available, and may be useful previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand.
    spark.executor.heartbeatIntervalspark.executor.heartbeatInterval 10000 Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let the driver know that the executor is still alive and update it with metrics for in-progress From 1baf06f4e6a2c4767ad6107559396c7680085235 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 13 Aug 2014 23:53:44 -0700 Subject: [PATCH 130/231] [SPARK-2995][MLLIB] add ALS.setIntermediateRDDStorageLevel As mentioned in SPARK-2465, using `MEMORY_AND_DISK_SER` for user/product in/out links together with `spark.rdd.compress=true` can help reduce the space requirement by a lot, at the cost of speed. It might be useful to add this option so people can run ALS on much bigger datasets. Another option for the method name is `setIntermediateRDDStorageLevel`. Author: Xiangrui Meng Closes #1913 from mengxr/als-storagelevel and squashes the following commits: d942017 [Xiangrui Meng] rename to setIntermediateRDDStorageLevel 7550029 [Xiangrui Meng] add ALS.setIntermediateDataStorageLevel (cherry picked from commit 69a57a18ee35af1cc5a00b67a80837ea317cd330) Signed-off-by: Xiangrui Meng --- .../spark/mllib/recommendation/ALS.scala | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 8ebc7e27ed4dd..84d192db53e26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -111,11 +111,17 @@ class ALS private ( */ def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) + /** If true, do alternating nonnegative least squares. */ + private var nonnegative = false + + /** storage level for user/product in/out links */ + private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. */ - def setBlocks(numBlocks: Int): ALS = { + def setBlocks(numBlocks: Int): this.type = { this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks this @@ -124,7 +130,7 @@ class ALS private ( /** * Set the number of user blocks to parallelize the computation. */ - def setUserBlocks(numUserBlocks: Int): ALS = { + def setUserBlocks(numUserBlocks: Int): this.type = { this.numUserBlocks = numUserBlocks this } @@ -132,31 +138,31 @@ class ALS private ( /** * Set the number of product blocks to parallelize the computation. */ - def setProductBlocks(numProductBlocks: Int): ALS = { + def setProductBlocks(numProductBlocks: Int): this.type = { this.numProductBlocks = numProductBlocks this } /** Set the rank of the feature matrices computed (number of features). Default: 10. */ - def setRank(rank: Int): ALS = { + def setRank(rank: Int): this.type = { this.rank = rank this } /** Set the number of iterations to run. Default: 10. */ - def setIterations(iterations: Int): ALS = { + def setIterations(iterations: Int): this.type = { this.iterations = iterations this } /** Set the regularization parameter, lambda. Default: 0.01. */ - def setLambda(lambda: Double): ALS = { + def setLambda(lambda: Double): this.type = { this.lambda = lambda this } /** Sets whether to use implicit preference. Default: false. */ - def setImplicitPrefs(implicitPrefs: Boolean): ALS = { + def setImplicitPrefs(implicitPrefs: Boolean): this.type = { this.implicitPrefs = implicitPrefs this } @@ -166,29 +172,38 @@ class ALS private ( * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. */ @Experimental - def setAlpha(alpha: Double): ALS = { + def setAlpha(alpha: Double): this.type = { this.alpha = alpha this } /** Sets a random seed to have deterministic results. */ - def setSeed(seed: Long): ALS = { + def setSeed(seed: Long): this.type = { this.seed = seed this } - /** If true, do alternating nonnegative least squares. */ - private var nonnegative = false - /** * Set whether the least-squares problems solved at each iteration should have * nonnegativity constraints. */ - def setNonnegative(b: Boolean): ALS = { + def setNonnegative(b: Boolean): this.type = { this.nonnegative = b this } + /** + * :: DeveloperApi :: + * Sets storage level for intermediate RDDs (user/product in/out links). The default value is + * `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g., `MEMORY_AND_DISK_SER` and + * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed. + */ + @DeveloperApi + def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { + this.intermediateRDDStorageLevel = storageLevel + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -441,8 +456,8 @@ class ALS private ( }, preservesPartitioning = true) val inLinks = links.mapValues(_._1) val outLinks = links.mapValues(_._2) - inLinks.persist(StorageLevel.MEMORY_AND_DISK) - outLinks.persist(StorageLevel.MEMORY_AND_DISK) + inLinks.persist(intermediateRDDStorageLevel) + outLinks.persist(intermediateRDDStorageLevel) (inLinks, outLinks) } From 0cb2b82e0ef903dd99c589928bc17650037f25c5 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 14 Aug 2014 01:37:38 -0700 Subject: [PATCH 131/231] [SPARK-3029] Disable local execution of Spark jobs by default Currently, local execution of Spark jobs is only used by take(), and it can be problematic as it can load a significant amount of data onto the driver. The worst case scenarios occur if the RDD is cached (guaranteed to load whole partition), has very large elements, or the partition is just large and we apply a filter with high selectivity or computational overhead. Additionally, jobs that run locally in this manner do not show up in the web UI, and are thus harder to track or understand what is occurring. This PR adds a flag to disable local execution, which is turned OFF by default, with the intention of perhaps eventually removing this functionality altogether. Removing it now is a tougher proposition since it is part of the public runJob API. An alternative solution would be to limit the flag to take()/first() to avoid impacting any external users of this API, but such usage (or, at least, reliance upon the feature) is hopefully minimal. Author: Aaron Davidson Closes #1321 from aarondav/allowlocal and squashes the following commits: 136b253 [Aaron Davidson] Fix DAGSchedulerSuite 5599d55 [Aaron Davidson] [RFC] Disable local execution of Spark jobs by default (cherry picked from commit d069c5d9d2f6ce06389ca2ddf0b3ae4db72c5797) Signed-off-by: Reynold Xin --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 7 ++++++- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 4 +++- docs/configuration.md | 9 +++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 430e45ada5808..36bbaaa3f1c85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -121,6 +121,9 @@ class DAGScheduler( private[scheduler] var eventProcessActor: ActorRef = _ + /** If enabled, we may run certain actions like take() and first() locally. */ + private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -732,7 +735,9 @@ class DAGScheduler( logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + val shouldRunLocally = + localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 + if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) runLocally(job) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8c1b0fed11f72..bd829752eb401 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -141,7 +141,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } before { - sc = new SparkContext("local", "DAGSchedulerSuite") + // Enable local execution for this test + val conf = new SparkConf().set("spark.localExecution.enabled", "true") + sc = new SparkContext("local", "DAGSchedulerSuite", conf) sparkListener.successfulStages.clear() sparkListener.failedStages.clear() failure = null diff --git a/docs/configuration.md b/docs/configuration.md index c8336b39133de..c408c468dcd94 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -846,6 +846,15 @@ Apart from these, the following properties are also available, and may be useful (in milliseconds).
    spark.localExecution.enabledfalse + Enables Spark to run certain jobs, such as first() or take() on the driver, without sending + tasks to the cluster. This can make certain jobs execute very quickly, but may require + shipping a whole partition of data to the driver. +
    #### Security From af809de77b5f939320c20d98d6c6dd98fcfd55a7 Mon Sep 17 00:00:00 2001 From: Graham Dennis Date: Thu, 14 Aug 2014 02:24:18 -0700 Subject: [PATCH 132/231] SPARK-2893: Do not swallow Exceptions when running a custom kryo registrator The previous behaviour of swallowing ClassNotFound exceptions when running a custom Kryo registrator could lead to difficult to debug problems later on at serialisation / deserialisation time, see SPARK-2878. Instead it is better to fail fast. Added test case. Author: Graham Dennis Closes #1827 from GrahamDennis/feature/spark-2893 and squashes the following commits: fbe4cb6 [Graham Dennis] [SPARK-2878]: Update the test case to match the updated exception message 65e53c5 [Graham Dennis] [SPARK-2893]: Improve message when a spark.kryo.registrator fails. f480d85 [Graham Dennis] [SPARK-2893] Fix typo. b59d2c2 [Graham Dennis] SPARK-2893: Do not swallow Exceptions when running a custom spark.kryo.registrator (cherry picked from commit 6b8de0e36c7548046c3b8a57f2c8e7e788dde8cc) Signed-off-by: Reynold Xin --- .../org/apache/spark/serializer/KryoSerializer.scala | 11 ++++++----- .../apache/spark/serializer/KryoSerializerSuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 407cb9db6ee9a..85944eabcfefc 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -79,15 +79,16 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) // Allow the user to register their own classes by setting spark.kryo.registrator - try { - for (regCls <- registrator) { - logDebug("Running user registrator: " + regCls) + for (regCls <- registrator) { + logDebug("Running user registrator: " + regCls) + try { val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] reg.registerClasses(kryo) + } catch { + case e: Exception => + throw new SparkException(s"Failed to invoke $regCls", e) } - } catch { - case e: Exception => logError("Failed to run spark.kryo.registrator", e) } // Register Chill's classes; we do this after our ranges and the user's own classes to let diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 789b773bae316..3bf9efebb39d2 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -207,6 +207,16 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } + + test("kryo with nonexistent custom registrator should fail") { + import org.apache.spark.{SparkConf, SparkException} + + val conf = new SparkConf(false) + conf.set("spark.kryo.registrator", "this.class.does.not.exist") + + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) + assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) + } } class KryoSerializerResizableOutputSuite extends FunSuite { From 221c84e6ab631a137165e0e6b41d4d10b018d2b6 Mon Sep 17 00:00:00 2001 From: Chia-Yung Su Date: Thu, 14 Aug 2014 10:43:08 -0700 Subject: [PATCH 133/231] [SPARK-3011][SQL] _temporary directory should be filtered out by sqlContext.parquetFile Author: Chia-Yung Su Closes #1924 from joesu/bugfix-spark3011 and squashes the following commits: c7e44f2 [Chia-Yung Su] match syntax f8fc32a [Chia-Yung Su] filter out tmp dir (cherry picked from commit 078f3fbda860e2f5de34153c55dfc3fecb4256e9) Signed-off-by: Michael Armbrust --- .../main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2867dc0a8b1f9..37091bcf73dd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -375,7 +375,8 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs.listStatus(path).filterNot { status => val name = status.getPath.getName - name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME + name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME || + name == FileOutputCommitter.TEMP_DIR_NAME } // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row From de501e169f24e4573747aec85b7651c98633c028 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 14 Aug 2014 10:46:33 -0700 Subject: [PATCH 134/231] [SPARK-2927][SQL] Add a conf to configure if we always read Binary columns stored in Parquet as String columns This PR adds a new conf flag `spark.sql.parquet.binaryAsString`. When it is `true`, if there is no parquet metadata file available to provide the schema of the data, we will always treat binary fields stored in parquet as string fields. This conf is used to provide a way to read string fields generated without UTF8 decoration. JIRA: https://issues.apache.org/jira/browse/SPARK-2927 Author: Yin Huai Closes #1855 from yhuai/parquetBinaryAsString and squashes the following commits: 689ffa9 [Yin Huai] Add missing "=". 80827de [Yin Huai] Unit test. 1765ca4 [Yin Huai] Use .toBoolean. 9d3f199 [Yin Huai] Merge remote-tracking branch 'upstream/master' into parquetBinaryAsString 5d436a1 [Yin Huai] The initial support of adding a conf to treat binary columns stored in Parquet as string columns. (cherry picked from commit add75d4831fdc35712bf8b737574ea0bc677c37c) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/SQLConf.scala | 10 +++- .../spark/sql/parquet/ParquetRelation.scala | 6 ++- .../sql/parquet/ParquetTableSupport.scala | 3 +- .../spark/sql/parquet/ParquetTypes.scala | 36 +++++++------ .../spark/sql/parquet/ParquetQuerySuite.scala | 54 +++++++++++++++++-- 5 files changed, 87 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 35c51dec0bcf5..90de11182e605 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -31,6 +31,7 @@ private[spark] object SQLConf { val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" + val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -87,8 +88,7 @@ trait SQLConf { * * Defaults to false as this feature is currently experimental. */ - private[spark] def codegenEnabled: Boolean = - if (getConf(CODEGEN_ENABLED, "false") == "true") true else false + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -108,6 +108,12 @@ trait SQLConf { private[spark] def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong + /** + * When set to true, we always treat byte arrays in Parquet files as strings. + */ + private[spark] def isParquetBinaryAsString: Boolean = + getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b3bae5db0edbc..053b2a154389c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -60,7 +60,11 @@ private[sql] case class ParquetRelation( .getSchema /** Attributes */ - override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) + override val output = + ParquetTypesConverter.readSchemaFromFile( + new Path(path), + conf, + sqlContext.isParquetBinaryAsString) override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 6d4ce32ac5bfa..6a657c20fe46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -80,9 +80,10 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { } } // if both unavailable, fall back to deducing the schema from the given Parquet schema + // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 37091bcf73dd6..b0579f76da073 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -43,10 +43,13 @@ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass - def toPrimitiveDataType(parquetType: ParquetPrimitiveType): DataType = + def toPrimitiveDataType( + parquetType: ParquetPrimitiveType, + binayAsString: Boolean): DataType = parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if parquetType.getOriginalType == ParquetOriginalType.UTF8 => StringType + if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || + binayAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -85,7 +88,7 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param parquetType The type to convert. * @return The corresponding Catalyst type. */ - def toDataType(parquetType: ParquetType): DataType = { + def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { def correspondsToMap(groupType: ParquetGroupType): Boolean = { if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { false @@ -107,7 +110,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType) + toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -116,7 +119,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - ArrayType(toDataType(field), containsNull = false) + ArrayType(toDataType(field, isBinaryAsString), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -126,9 +129,9 @@ private[parquet] object ParquetTypesConverter extends Logging { assert( keyValueGroup.getFieldCount == 2, "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. @@ -138,22 +141,22 @@ private[parquet] object ParquetTypesConverter extends Logging { // Note: the order of these checks is important! if (correspondsToMap(groupType)) { // MapType val keyValueGroup = groupType.getFields.apply(0).asGroupType() - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType - val elementType = toDataType(groupType.getFields.apply(0)) + val elementType = toDataType(groupType.getFields.apply(0), isBinaryAsString) ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields .map(ptype => new StructField( ptype.getName, - toDataType(ptype), + toDataType(ptype, isBinaryAsString), ptype.getRepetition != Repetition.REQUIRED)) StructType(fields) } @@ -276,7 +279,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = { + def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { parquetSchema .asGroupType() .getFields @@ -284,7 +287,7 @@ private[parquet] object ParquetTypesConverter extends Logging { field => new AttributeReference( field.getName, - toDataType(field), + toDataType(field, isBinaryAsString), field.getRepetition != Repetition.REQUIRED)()) } @@ -404,7 +407,10 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param conf The Hadoop configuration to use. * @return A list of attributes that make up the schema. */ - def readSchemaFromFile(origPath: Path, conf: Option[Configuration]): Seq[Attribute] = { + def readSchemaFromFile( + origPath: Path, + conf: Option[Configuration], + isBinaryAsString: Boolean): Seq[Attribute] = { val keyValueMetadata: java.util.Map[String, String] = readMetaData(origPath, conf) .getFileMetaData @@ -413,7 +419,7 @@ private[parquet] object ParquetTypesConverter extends Logging { convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) } else { val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema) + readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 502f6702e394e..172dcd6aa0ee3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -21,8 +21,6 @@ import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil -import parquet.schema.MessageTypeParser - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -33,7 +31,6 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -138,6 +135,57 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } + test("Treat binary as string") { + val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString + + // Create the test file. + val file = getTempFilePath("parquet") + val path = file.toString + val range = (0 to 255) + val rowRDD = TestSQLContext.sparkContext.parallelize(range) + .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes)) + // We need to ask Parquet to store the String column as a Binary column. + val schema = StructType( + StructField("c1", IntegerType, false) :: + StructField("c2", BinaryType, false) :: Nil) + val schemaRDD1 = applySchema(rowRDD, schema) + schemaRDD1.saveAsParquetFile(path) + val resultWithBinary = parquetFile(path).collect + range.foreach { + i => + assert(resultWithBinary(i).getInt(0) === i) + assert(resultWithBinary(i)(1) === s"val_$i".getBytes) + } + + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + // This ParquetRelation always use Parquet types to derive output. + val parquetRelation = new ParquetRelation( + path.toString, + Some(TestSQLContext.sparkContext.hadoopConfiguration), + TestSQLContext) { + override val output = + ParquetTypesConverter.convertToAttributes( + ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, + TestSQLContext.isParquetBinaryAsString) + } + val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) + val resultWithString = schemaRDD.collect + range.foreach { + i => + assert(resultWithString(i).getInt(0) === i) + assert(resultWithString(i)(1) === s"val_$i") + } + + schemaRDD.registerTempTable("tmp") + checkAnswer( + sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + // Set it back. + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) + } + test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) From 850abaa36043104e5f09bf2754d1ae3f9ce86e3d Mon Sep 17 00:00:00 2001 From: Ahir Reddy Date: Thu, 14 Aug 2014 10:48:52 -0700 Subject: [PATCH 135/231] [SQL] Python JsonRDD UTF8 Encoding Fix Only encode unicode objects to UTF-8, and not strings Author: Ahir Reddy Closes #1914 from ahirreddy/json-rdd-unicode-fix1 and squashes the following commits: ca4e9ba [Ahir Reddy] Encoding Fix (cherry picked from commit fde692b361773110c262abe219e7c8128bd76419) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 46540ca3f1e8a..95086a2258222 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1267,7 +1267,9 @@ def func(iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) - yield x.encode("utf-8") + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x keyed = rdd.mapPartitions(func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) From df25acdf447bfac9c41440f49bd3bbe1c5d34696 Mon Sep 17 00:00:00 2001 From: wangfei Date: Thu, 14 Aug 2014 10:55:51 -0700 Subject: [PATCH 136/231] [SPARK-2925] [sql]fix spark-sql and start-thriftserver shell bugs when set --driver-java-options https://issues.apache.org/jira/browse/SPARK-2925 Run cmd like this will get the error bin/spark-sql --driver-java-options '-Xdebug -Xnoagent -Xrunjdwp:transport=dt_socket,address=8788,server=y,suspend=y' Error: Unrecognized option '-Xnoagent'. Run with --help for usage help or --verbose for debug output Author: wangfei Author: wangfei Closes #1851 from scwf/patch-2 and squashes the following commits: 516554d [wangfei] quote variables to fix this issue 8bd40f2 [wangfei] quote variables to fix this problem e6d79e3 [wangfei] fix start-thriftserver bug when set driver-java-options 948395d [wangfei] fix spark-sql error when set --driver-java-options (cherry picked from commit 267fdffe2743bc2dc706c8ac8af0ae33a358a5d3) Signed-off-by: Michael Armbrust --- bin/spark-sql | 18 +++++++++--------- sbin/start-thriftserver.sh | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bin/spark-sql b/bin/spark-sql index 7813ccc361415..564f1f419060f 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -65,30 +65,30 @@ while (($#)); do case $1 in -d | --define | --database | -f | -h | --hiveconf | --hivevar | -i | -p) ensure_arg_number $# 2 - CLI_ARGS+=($1); shift - CLI_ARGS+=($1); shift + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift ;; -e) ensure_arg_number $# 2 - CLI_ARGS+=($1); shift - CLI_ARGS+=(\"$1\"); shift + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift ;; -s | --silent) - CLI_ARGS+=($1); shift + CLI_ARGS+=("$1"); shift ;; -v | --verbose) # Both SparkSubmit and SparkSQLCLIDriver recognizes -v | --verbose - CLI_ARGS+=($1) - SUBMISSION_ARGS+=($1); shift + CLI_ARGS+=("$1") + SUBMISSION_ARGS+=("$1"); shift ;; *) - SUBMISSION_ARGS+=($1); shift + SUBMISSION_ARGS+=("$1"); shift ;; esac done -eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${CLI_ARGS[*]} +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${CLI_ARGS[@]}" diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 603f50ae13240..2c4452473ccbc 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -65,14 +65,14 @@ while (($#)); do case $1 in --hiveconf) ensure_arg_number $# 2 - THRIFT_SERVER_ARGS+=($1); shift - THRIFT_SERVER_ARGS+=($1); shift + THRIFT_SERVER_ARGS+=("$1"); shift + THRIFT_SERVER_ARGS+=("$1"); shift ;; *) - SUBMISSION_ARGS+=($1); shift + SUBMISSION_ARGS+=("$1"); shift ;; esac done -eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${THRIFT_SERVER_ARGS[*]} +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${THRIFT_SERVER_ARGS[@]}" From a3dc54fa11c5323ec191df52c06443d3f96956d4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Aug 2014 11:22:41 -0700 Subject: [PATCH 137/231] Minor cleanup of metrics.Source - Added override. - Marked some variables as private. Author: Reynold Xin Closes #1943 from rxin/metricsSource and squashes the following commits: fbfa943 [Reynold Xin] Minor cleanup of metrics.Source. - Added override. - Marked some variables as private. (cherry picked from commit eaeb0f76fa0f103c7db0f3975cb8562715410973) Signed-off-by: Reynold Xin --- .../spark/deploy/master/ApplicationSource.scala | 4 ++-- .../org/apache/spark/deploy/master/MasterSource.scala | 4 ++-- .../org/apache/spark/deploy/worker/WorkerSource.scala | 4 ++-- .../org/apache/spark/executor/ExecutorSource.scala | 5 +++-- .../org/apache/spark/metrics/source/JvmSource.scala | 11 ++++------- .../apache/spark/scheduler/DAGSchedulerSource.scala | 4 ++-- .../org/apache/spark/storage/BlockManagerSource.scala | 4 ++-- .../org/apache/spark/streaming/StreamingSource.scala | 6 +++--- 8 files changed, 20 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala index c87b66f047dc8..38db02cd2421b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source class ApplicationSource(val application: ApplicationInfo) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.%s.%s".format("application", application.desc.name, + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.%s.%s".format("application", application.desc.name, System.currentTimeMillis()) metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 36c1b87b7f684..9c3f79f1244b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class MasterSource(val master: Master) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "master" + override val metricRegistry = new MetricRegistry() + override val sourceName = "master" // Gauge for worker numbers in cluster metricRegistry.register(MetricRegistry.name("workers"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala index b7ddd8c816cbc..df1e01b23b932 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class WorkerSource(val worker: Worker) extends Source { - val sourceName = "worker" - val metricRegistry = new MetricRegistry() + override val sourceName = "worker" + override val metricRegistry = new MetricRegistry() metricRegistry.register(MetricRegistry.name("executors"), new Gauge[Int] { override def getValue: Int = worker.executors.size diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 0ed52cfe9df61..d6721586566c2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -35,9 +35,10 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) }) } - val metricRegistry = new MetricRegistry() + override val metricRegistry = new MetricRegistry() + // TODO: It would be nice to pass the application name here - val sourceName = "executor.%s".format(executorId) + override val sourceName = "executor.%s".format(executorId) // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala index f865f9648a91e..635bff2cd7ec8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -21,12 +21,9 @@ import com.codahale.metrics.MetricRegistry import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} private[spark] class JvmSource extends Source { - val sourceName = "jvm" - val metricRegistry = new MetricRegistry() + override val sourceName = "jvm" + override val metricRegistry = new MetricRegistry() - val gcMetricSet = new GarbageCollectorMetricSet - val memGaugeSet = new MemoryUsageGaugeSet - - metricRegistry.registerAll(gcMetricSet) - metricRegistry.registerAll(memGaugeSet) + metricRegistry.registerAll(new GarbageCollectorMetricSet) + metricRegistry.registerAll(new MemoryUsageGaugeSet) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 5878e733908f5..94944399b134a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.DAGScheduler".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.DAGScheduler".format(sc.appName) metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.failedStages.size diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 3f14c40ec61cb..49fea6d9e2a76 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.BlockManager".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.BlockManager".format(sc.appName) metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 774adc3c23c21..75f0e8716dc7e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -23,10 +23,10 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.streaming.ui.StreamingJobProgressListener private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { - val metricRegistry = new MetricRegistry - val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) + override val metricRegistry = new MetricRegistry + override val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) - val streamingListener = ssc.uiTab.listener + private val streamingListener = ssc.uiTab.listener private def registerGauge[T](name: String, f: StreamingJobProgressListener => T, defaultValue: T) { From dc8ef9387247e191406d8ff2df7af27bba007f53 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 14 Aug 2014 11:56:13 -0700 Subject: [PATCH 138/231] [SPARK-2979][MLlib] Improve the convergence rate by minimizing the condition number MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In theory, the scale of your inputs are irrelevant to logistic regression. You can "theoretically" multiply X1 by 1E6 and the estimate for β1 will adjust accordingly. It will be 1E-6 times smaller than the original β1, due to the invariance property of MLEs. However, during the optimization process, the convergence (rate) depends on the condition number of the training dataset. Scaling the variables often reduces this condition number, thus improving the convergence rate. Without reducing the condition number, some training datasets mixing the columns with different scales may not be able to converge. GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return the weights in the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf Here, if useFeatureScaling is enabled, we will standardize the training features by dividing the variance of each column (without subtracting the mean to densify the sparse vector), and train the model in the scaled space. Then we transform the coefficients from the scaled space to the original scale as GLMNET and LIBSVM do. Currently, it's only enabled in LogisticRegressionWithLBFGS. Author: DB Tsai Closes #1897 from dbtsai/dbtsai-feature-scaling and squashes the following commits: f19fc02 [DB Tsai] Added more comments 1d85289 [DB Tsai] Improve the convergence rate by minimize the condition number in LOR with LBFGS (cherry picked from commit 96221067572e5955af1a7710b0cca33a73db4bd5) Signed-off-by: Xiangrui Meng --- .../classification/LogisticRegression.scala | 4 +- .../GeneralizedLinearAlgorithm.scala | 69 ++++++++++++++++++- .../LogisticRegressionSuite.scala | 57 +++++++++++++++ 3 files changed, 126 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 31d474a20fa85..6790c86f651b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,7 +62,7 @@ class LogisticRegressionModel ( override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - val score = 1.0/ (1.0 + math.exp(-margin)) + val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score < t) 0.0 else 1.0 case None => score @@ -204,6 +204,8 @@ class LogisticRegressionWithLBFGS private ( */ def this() = this(1E-4, 100, 0.0) + this.setFeatureScaling(true) + private val gradient = new LogisticGradient() private val updater = new SimpleUpdater() // Have to return new LBFGS object every time since users can reset the parameters anytime. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 54854252d7477..20c1fdd2269ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ @@ -94,6 +95,22 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] protected var validateData: Boolean = true + /** + * Whether to perform feature scaling before model training to reduce the condition numbers + * which can significantly help the optimizer converging faster. The scaling correction will be + * translated back to resulting model weights, so it's transparent to users. + * Note: This technique is used in both libsvm and glmnet packages. Default false. + */ + private var useFeatureScaling = false + + /** + * Set if the algorithm should use feature scaling to improve the convergence during optimization. + */ + private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = { + this.useFeatureScaling = useFeatureScaling + this + } + /** * Create a model given the weights and intercept */ @@ -137,11 +154,45 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } + /** + * Scaling columns to unit variance as a heuristic to reduce the condition number: + * + * During the optimization process, the convergence (rate) depends on the condition number of + * the training dataset. Scaling the variables often reduces this condition number + * heuristically, thus improving the convergence rate. Without reducing the condition number, + * some training datasets mixing the columns with different scales may not be able to converge. + * + * GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return + * the weights in the original scale. + * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * Here, if useFeatureScaling is enabled, we will standardize the training features by dividing + * the variance of each column (without subtracting the mean), and train the model in the + * scaled space. Then we transform the coefficients from the scaled space to the original scale + * as GLMNET and LIBSVM do. + * + * Currently, it's only enabled in LogisticRegressionWithLBFGS + */ + val scaler = if (useFeatureScaling) { + (new StandardScaler).fit(input.map(x => x.features)) + } else { + null + } + // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + if(useFeatureScaling) { + input.map(labeledPoint => + (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) + } else { + input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + } else { + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + } } val initialWeightsWithIntercept = if (addIntercept) { @@ -153,13 +204,25 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 - val weights = + var weights = if (addIntercept) { Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) } else { weightsWithIntercept } + /** + * The weights and intercept are trained in the scaled space; we're converting them back to + * the original scale. + * + * Math shows that if we only perform standardization without subtracting means, the intercept + * will not be changed. w_i = w_i' / v_i where w_i' is the coefficient in the scaled space, w_i + * is the coefficient in the original space, and v_i is the variance of the column i. + */ + if (useFeatureScaling) { + weights = scaler.transform(weights) + } + createModel(weights, intercept) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 2289c6cdc19de..bc05b2046878f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -185,6 +185,63 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("numerical stability of scaling features using logistic regression with LBFGS") { + /** + * If we rescale the features, the condition number will be changed so the convergence rate + * and the solution will not equal to the original solution multiple by the scaling factor + * which it should be. + * + * However, since in the LogisticRegressionWithLBFGS, we standardize the training dataset first, + * no matter how we multiple a scaling factor into the dataset, the convergence rate should be + * the same, and the solution should equal to the original solution multiple by the scaling + * factor. + */ + + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialWeights = Vectors.dense(0.0) + + val testRDD1 = sc.parallelize(testData, 2) + + val testRDD2 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2) + + val testRDD3 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2) + + testRDD1.cache() + testRDD2.cache() + testRDD3.cache() + + val lrA = new LogisticRegressionWithLBFGS().setIntercept(true) + val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + + val modelA1 = lrA.run(testRDD1, initialWeights) + val modelA2 = lrA.run(testRDD2, initialWeights) + val modelA3 = lrA.run(testRDD3, initialWeights) + + val modelB1 = lrB.run(testRDD1, initialWeights) + val modelB2 = lrB.run(testRDD2, initialWeights) + val modelB3 = lrB.run(testRDD3, initialWeights) + + // For model trained with feature standardization, the weights should + // be the same in the scaled space. Note that the weights here are already + // in the original space, we transform back to scaled space to compare. + assert(modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 absTol 0.01) + assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01) + + // Training data with different scales without feature standardization + // will not yield the same result in the scaled space due to poor + // convergence rate. + assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1) + assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) + } + } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { From c39a3f337cfed86b3c75d90f33319498ed9a3255 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 14 Aug 2014 13:00:21 -0700 Subject: [PATCH 139/231] Revert [SPARK-3011][SQL] _temporary directory should be filtered out by sqlContext.parquetFile Reverts #1924 due to build failures with hadoop 0.23. Author: Michael Armbrust Closes #1949 from marmbrus/revert1924 and squashes the following commits: 6bff940 [Michael Armbrust] Revert "[SPARK-3011][SQL] _temporary directory should be filtered out by sqlContext.parquetFile" (cherry picked from commit a7f8a4f5ee757450ce8d4028021441435081cf53) Signed-off-by: Michael Armbrust --- .../main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index b0579f76da073..c79a9ac2dad81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -378,8 +378,7 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs.listStatus(path).filterNot { status => val name = status.getPath.getName - name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME || - name == FileOutputCommitter.TEMP_DIR_NAME + name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME } // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row From f5d9176fba934fa1f440d14d1ac7cd6f149434c4 Mon Sep 17 00:00:00 2001 From: Jacek Lewandowski Date: Thu, 14 Aug 2014 15:01:39 -0700 Subject: [PATCH 140/231] SPARK-3009: Reverted readObject method in ApplicationInfo so that Applic... ...ationInfo is initialized properly after deserialization Author: Jacek Lewandowski Closes #1947 from jacek-lewandowski/master and squashes the following commits: 713b2f1 [Jacek Lewandowski] SPARK-3009: Reverted readObject method in ApplicationInfo so that ApplicationInfo is initialized properly after deserialization (cherry picked from commit a75bc7a21db07258913d038bf604c0a3c1e55b46) Signed-off-by: Andrew Or --- .../org/apache/spark/deploy/master/ApplicationInfo.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 72d0589689e71..d3674427b1271 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -46,6 +46,11 @@ private[spark] class ApplicationInfo( init() + private def readObject(in: java.io.ObjectInputStream): Unit = { + in.defaultReadObject() + init() + } + private def init() { state = ApplicationState.WAITING executors = new mutable.HashMap[Int, ExecutorInfo] From 475a35ba4f3a641a775bb4a71481bf95e6dd3509 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Aug 2014 16:27:11 -0700 Subject: [PATCH 141/231] Make dev/mima runnable on Mac OS X. Mac OS X's find is from the BSD variant that doesn't have -printf option. Author: Reynold Xin Closes #1953 from rxin/mima and squashes the following commits: e284afe [Reynold Xin] Make dev/mima runnable on Mac OS X. (cherry picked from commit fa5a08e67d1086045ac249c2090c5e4d0a17b828) Signed-off-by: Reynold Xin --- dev/mima | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/mima b/dev/mima index 4c3e65039b160..09e4482af5f3d 100755 --- a/dev/mima +++ b/dev/mima @@ -26,7 +26,9 @@ cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update -export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" ` +export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"` +echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" + ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? From f99e4fc80615a1e0861359ab1ebc2e8335c7a022 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Aug 2014 18:37:02 -0700 Subject: [PATCH 142/231] [SPARK-3027] TaskContext: tighten visibility and provide Java friendly callback API Note this also passes the TaskContext itself to the TaskCompletionListener. In the future we can mark TaskContext with the exception object if exception occurs during task execution. Author: Reynold Xin Closes #1938 from rxin/TaskContext and squashes the following commits: 145de43 [Reynold Xin] Added JavaTaskCompletionListenerImpl for Java API friendly guarantee. f435ea5 [Reynold Xin] Added license header for TaskCompletionListener. dc4ed27 [Reynold Xin] [SPARK-3027] TaskContext: tighten the visibility and provide Java friendly callback API (cherry picked from commit 655699f8b7156e8216431393436368e80626cdb2) Signed-off-by: Reynold Xin --- .../apache/spark/InterruptibleIterator.scala | 2 +- .../scala/org/apache/spark/TaskContext.scala | 63 ++++++++++++++++--- .../apache/spark/api/python/PythonRDD.scala | 12 ++-- .../org/apache/spark/rdd/CheckpointRDD.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../scala/org/apache/spark/rdd/JdbcRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../apache/spark/scheduler/ResultTask.scala | 2 +- .../spark/scheduler/ShuffleMapTask.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 2 +- .../spark/util/TaskCompletionListener.scala | 33 ++++++++++ .../util/JavaTaskCompletionListenerImpl.java | 39 ++++++++++++ .../spark/scheduler/TaskContextSuite.scala | 2 +- 14 files changed, 144 insertions(+), 23 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala create mode 100644 core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index f40baa8e43592..5c262bcbddf76 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. - if (context.interrupted) { + if (context.isInterrupted) { throw new TaskKilledException } else { delegate.hasNext diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 51f40c339d13c..2b99b8a5af250 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.TaskCompletionListener + /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task */ @DeveloperApi class TaskContext( @@ -39,13 +47,45 @@ class TaskContext( def splitId = partitionId // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] // Whether the corresponding task has been killed. - @volatile var interrupted: Boolean = false + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + /** Checks whether the task has completed. */ + def isCompleted: Boolean = completed - // Whether the task has completed, before the onCompleteCallbacks are executed. - @volatile var completed: Boolean = false + /** Checks whether the task has been killed. */ + def isInterrupted: Boolean = interrupted + + // TODO: Also track whether the task has completed successfully or with exception. + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } /** * Add a callback function to be executed on task completion. An example use @@ -53,13 +93,22 @@ class TaskContext( * Will be called in any situation - success, failure, or cancellation. * @param f Callback function. */ + @deprecated("use addTaskCompletionListener", "1.1.0") def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += f + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } } - def executeOnCompleteCallbacks() { + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { completed = true // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { _() } + onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) } + } + + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(): Unit = { + interrupted = true } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0b5322c6fb965..fefe1cb6f134c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -68,7 +68,7 @@ private[spark] class PythonRDD( // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - context.addOnCompleteCallback { () => + context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() // Cleanup the worker socket. This will also cause the Python worker to exit. @@ -137,7 +137,7 @@ private[spark] class PythonRDD( } } catch { - case e: Exception if context.interrupted => + case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) throw new TaskKilledException @@ -176,7 +176,7 @@ private[spark] class PythonRDD( /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ def shutdownOnTaskCompletion() { - assert(context.completed) + assert(context.isCompleted) this.interrupt() } @@ -209,7 +209,7 @@ private[spark] class PythonRDD( PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) dataOut.flush() } catch { - case e: Exception if context.completed || context.interrupted => + case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) case e: Exception => @@ -235,10 +235,10 @@ private[spark] class PythonRDD( override def run() { // Kill the worker if it is interrupted, checking until task completion. // TODO: This has a race condition if interruption occurs, as completed may still become true. - while (!context.interrupted && !context.completed) { + while (!context.isInterrupted && !context.isCompleted) { Thread.sleep(2000) } - if (!context.completed) { + if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") env.destroyPythonWorker(pythonExec, envVars.toMap, worker) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 34c51b833025e..20938781ac694 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging { val deserializeStream = serializer.deserializeStream(fileInputStream) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => deserializeStream.close()) + context.addTaskCompletionListener(context => deserializeStream.close()) deserializeStream.asIterator.asInstanceOf[Iterator[T]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8d92ea01d9a3f..c8623314c98eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -197,7 +197,7 @@ class HadoopRDD[K, V]( reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val key: K = reader.createKey() val value: V = reader.createValue() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 8947e66f4577c..0e38f224ac81d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag]( } override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { - context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7dfec9a18ec67..58f707b9b4634 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -129,7 +129,7 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) + context.addTaskCompletionListener(context => close()) var havePair = false var finished = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 36bbaaa3f1c85..b86cfbfa48fbe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -634,7 +634,7 @@ class DAGScheduler( val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { - taskContext.executeOnCompleteCallbacks() + taskContext.markTaskCompleted() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index d09fd7aa57642..2ccbd8edeb028 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U]( try { func(context, rdd.iterator(partition, context)) } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 11255c07469d4..381eff2147e95 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask( } throw e } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index cbe0bc0bcb0a5..6aa0cca06878d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex def kill(interruptThread: Boolean) { _killed = true if (context != null) { - context.interrupted = true + context.markInterrupted() } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala new file mode 100644 index 0000000000000..c1b8bf052c0ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.EventListener + +import org.apache.spark.TaskContext +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution completes. + */ +@DeveloperApi +trait TaskCompletionListener extends EventListener { + def onTaskCompletion(context: TaskContext) +} diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java new file mode 100644 index 0000000000000..af34cdb03e4d1 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util; + +import org.apache.spark.TaskContext; + + +/** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ +public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { + + @Override + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.runningLocally(); + context.taskMetrics(); + context.addTaskCompletionListener(this); + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 270f7e661045a..db2ad829a48f9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => TaskContextSuite.completed = true) + context.addTaskCompletionListener(context => TaskContextSuite.completed = true) sys.error("failed") } } From 72e730e9828bb3d88c69a36a241c2e332fca5629 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Thu, 14 Aug 2014 19:03:51 -0700 Subject: [PATCH 143/231] [SPARK-2736] PySpark converter and example script for reading Avro files JIRA: https://issues.apache.org/jira/browse/SPARK-2736 This patch includes: 1. An Avro converter that converts Avro data types to Python. It handles all 3 Avro data mappings (Generic, Specific and Reflect). 2. An example Python script for reading Avro files using AvroKeyInputFormat and the converter. 3. Fixing a classloading issue. cc @MLnick @JoshRosen @mateiz Author: Kan Zhang Closes #1916 from kanzhang/SPARK-2736 and squashes the following commits: 02443f8 [Kan Zhang] [SPARK-2736] Adding .avsc files to .rat-excludes f74e9a9 [Kan Zhang] [SPARK-2736] nit: clazz -> className 82cc505 [Kan Zhang] [SPARK-2736] Update data sample 0be7761 [Kan Zhang] [SPARK-2736] Example pyspark script and data files c8e5881 [Kan Zhang] [SPARK-2736] Trying to work with all 3 Avro data models 2271a5b [Kan Zhang] [SPARK-2736] Using the right class loader to find Avro classes 536876b [Kan Zhang] [SPARK-2736] Adding Avro to Java converter (cherry picked from commit 9422a9b084e3fd5b2b9be2752013588adfb430d0) Signed-off-by: Matei Zaharia --- .rat-excludes | 1 + .../spark/api/python/PythonHadoopUtil.scala | 3 +- .../apache/spark/api/python/PythonRDD.scala | 24 ++-- .../scala/org/apache/spark/util/Utils.scala | 3 + examples/src/main/python/avro_inputformat.py | 75 ++++++++++ examples/src/main/resources/user.avsc | 8 ++ examples/src/main/resources/users.avro | Bin 0 -> 334 bytes .../pythonconverters/AvroConverters.scala | 130 ++++++++++++++++++ 8 files changed, 231 insertions(+), 13 deletions(-) create mode 100644 examples/src/main/python/avro_inputformat.py create mode 100644 examples/src/main/resources/user.avsc create mode 100644 examples/src/main/resources/users.avro create mode 100644 examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala diff --git a/.rat-excludes b/.rat-excludes index bccb043c2bb55..eaefef1b0aa2e 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -25,6 +25,7 @@ log4j-defaults.properties bootstrap-tooltip.js jquery-1.11.1.min.js sorttable.js +.*avsc .*txt .*json .*data diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index f3b05e1243045..49dc95f349eac 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SerializableWritable, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ @@ -42,7 +43,7 @@ private[python] object Converter extends Logging { defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { - val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]] logInfo(s"Loaded converter: $cc") c } match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fefe1cb6f134c..9f5c5bd30f0c9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -372,8 +372,8 @@ private[spark] object PythonRDD extends Logging { batchSize: Int) = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, @@ -440,9 +440,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { @@ -509,9 +509,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { @@ -558,7 +558,7 @@ private[spark] object PythonRDD extends Logging { for { k <- Option(keyClass) v <- Option(valueClass) - } yield (Class.forName(k), Class.forName(v)) + } yield (Utils.classForName(k), Utils.classForName(v)) } private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, @@ -621,10 +621,10 @@ private[spark] object PythonRDD extends Logging { val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) - val codec = Option(compressionCodecClass).map(Class.forName(_).asInstanceOf[Class[C]]) + val codec = Option(compressionCodecClass).map(Utils.classForName(_).asInstanceOf[Class[C]]) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) } @@ -653,7 +653,7 @@ private[spark] object PythonRDD extends Logging { val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8cac5da644fa9..019f68b160894 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -146,6 +146,9 @@ private[spark] object Utils extends Logging { Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess } + /** Preferred alternative to Class.forName(className) */ + def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + /** * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. */ diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py new file mode 100644 index 0000000000000..e902ae29753c0 --- /dev/null +++ b/examples/src/main/python/avro_inputformat.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Read data file users.avro in local Spark distro: + +$ cd $SPARK_HOME +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro +{u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} +{u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} + +To read name and favorite_color fields only, specify the following reader schema: + +$ cat examples/src/main/resources/user.avsc +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} + +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro examples/src/main/resources/user.avsc +{u'favorite_color': None, u'name': u'Alyssa'} +{u'favorite_color': u'red', u'name': u'Ben'} +""" +if __name__ == "__main__": + if len(sys.argv) != 2 and len(sys.argv) != 3: + print >> sys.stderr, """ + Usage: avro_inputformat [reader_schema_file] + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/avro_inputformat.py [reader_schema_file] + Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. + """ + exit(-1) + + path = sys.argv[1] + sc = SparkContext(appName="AvroKeyInputFormat") + + conf = None + if len(sys.argv) == 3: + schema_rdd = sc.textFile(sys.argv[2], 1).collect() + conf = {"avro.schema.input.key" : reduce(lambda x, y: x+y, schema_rdd)} + + avro_rdd = sc.newAPIHadoopFile(path, + "org.apache.avro.mapreduce.AvroKeyInputFormat", + "org.apache.avro.mapred.AvroKey", + "org.apache.hadoop.io.NullWritable", + keyConverter="org.apache.spark.examples.pythonconverters.AvroWrapperToJavaConverter", + conf=conf) + output = avro_rdd.map(lambda x: x[0]).collect() + for k in output: + print k diff --git a/examples/src/main/resources/user.avsc b/examples/src/main/resources/user.avsc new file mode 100644 index 0000000000000..4995357ab3736 --- /dev/null +++ b/examples/src/main/resources/user.avsc @@ -0,0 +1,8 @@ +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} diff --git a/examples/src/main/resources/users.avro b/examples/src/main/resources/users.avro new file mode 100644 index 0000000000000000000000000000000000000000..27c526ab114b2f42f6d4e13325c373706ba0f880 GIT binary patch literal 334 zcmeZI%3@>@ODrqO*DFrWNX<=rz+A0VQdy9yWTl`~l$xAhl%k}gpp=)Gn_66um<$$9 ztw_u*$Vt@$>4Hgul!q3l7J>L_nW;G`#Xym0gi*yMMVWc&$f`j`D%I*Jz|}-6At@@& z$x(`hS`0EfEwL=WD6=FrJ~=-pzX(NNwGvP~7i6DOW?l)%3Yhy7i;5B}L2AM7M=>U^ zG&d==s932swpIk}`{ewT)MSo4puG%vlk4vPb+WF0^sw`-e)omlECxJ|IhDo5iA)@9 TLUI}mY)+|p3~WWIDHtjNiNSH? literal 0 HcmV?d00001 diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala new file mode 100644 index 0000000000000..1b25983a38453 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.pythonconverters + +import java.util.{Collection => JCollection, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.avro.generic.{GenericFixed, IndexedRecord} +import org.apache.avro.mapred.AvroWrapper +import org.apache.avro.Schema +import org.apache.avro.Schema.Type._ + +import org.apache.spark.api.python.Converter +import org.apache.spark.SparkException + + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries + * to work with all 3 Avro data mappings (Generic, Specific and Reflect). + */ +class AvroWrapperToJavaConverter extends Converter[Any, Any] { + override def convert(obj: Any): Any = { + if (obj == null) { + return null + } + obj.asInstanceOf[AvroWrapper[_]].datum() match { + case null => null + case record: IndexedRecord => unpackRecord(record) + case other => throw new SparkException( + s"Unsupported top-level Avro data type ${other.getClass.getName}") + } + } + + def unpackRecord(obj: Any): JMap[String, Any] = { + val map = new java.util.HashMap[String, Any] + obj match { + case record: IndexedRecord => + record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + map.put(f.name, fromAvro(record.get(i), f.schema)) + } + case other => throw new SparkException( + s"Unsupported RECORD type ${other.getClass.getName}") + } + map + } + + def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { + obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + (key.toString, fromAvro(value, schema.getValueType)) + } + } + + def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { + unpackBytes(obj.asInstanceOf[GenericFixed].bytes()) + } + + def unpackBytes(obj: Any): Array[Byte] = { + val bytes: Array[Byte] = obj match { + case buf: java.nio.ByteBuffer => buf.array() + case arr: Array[Byte] => arr + case other => throw new SparkException( + s"Unknown BYTES type ${other.getClass.getName}") + } + val bytearray = new Array[Byte](bytes.length) + System.arraycopy(bytes, 0, bytearray, 0, bytes.length) + bytearray + } + + def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { + case c: JCollection[_] => + c.map(fromAvro(_, schema.getElementType)) + case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => + arr.toSeq + case arr: Array[_] => + arr.map(fromAvro(_, schema.getElementType)).toSeq + case other => throw new SparkException( + s"Unknown ARRAY type ${other.getClass.getName}") + } + + def unpackUnion(obj: Any, schema: Schema): Any = { + schema.getTypes.toList match { + case List(s) => fromAvro(obj, s) + case List(n, s) if n.getType == NULL => fromAvro(obj, s) + case List(s, n) if n.getType == NULL => fromAvro(obj, s) + case _ => throw new SparkException( + "Unions may only consist of a concrete type and null") + } + } + + def fromAvro(obj: Any, schema: Schema): Any = { + if (obj == null) { + return null + } + schema.getType match { + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj + case BOOLEAN => obj + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException( + s"Unknown Avro schema type ${other.getName}") + } + } +} From d3cce5821ebdbe1e6a91bf7fe1efc00c23e62b08 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 10 Aug 2014 20:36:54 -0700 Subject: [PATCH 144/231] [SPARK-2936] Migrate Netty network module from Java to Scala The Netty network module was originally written when Scala 2.9.x had a bug that prevents a pure Scala implementation, and a subset of the files were done in Java. We have since upgraded to Scala 2.10, and can migrate all Java files now to Scala. https://github.com/netty/netty/issues/781 https://github.com/mesos/spark/pull/522 Author: Reynold Xin Closes #1865 from rxin/netty and squashes the following commits: 332422f [Reynold Xin] Code review feedback ca9eeee [Reynold Xin] Minor update. 7f1434b [Reynold Xin] [SPARK-2936] Migrate Netty network module from Java to Scala (cherry picked from commit ba28a8fcbc3ba432e7ea4d6f0b535450a6ec96c6) Signed-off-by: Reynold Xin --- .../spark/network/netty/FileClient.java | 100 ---------------- .../spark/network/netty/FileServer.java | 111 ------------------ .../network/netty/FileServerHandler.java | 83 ------------- .../spark/network/netty/FileClient.scala | 85 ++++++++++++++ .../netty/FileClientChannelInitializer.scala} | 24 ++-- .../network/netty/FileClientHandler.scala} | 47 ++++---- .../spark/network/netty/FileHeader.scala | 5 +- .../spark/network/netty/FileServer.scala | 91 ++++++++++++++ .../netty/FileServerChannelInitializer.scala} | 31 ++--- .../network/netty/FileServerHandler.scala | 68 +++++++++++ .../spark/network/netty/PathResolver.scala} | 9 +- .../spark/network/netty/ShuffleSender.scala | 2 +- 12 files changed, 292 insertions(+), 364 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClient.java delete mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServer.java delete mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClient.scala rename core/src/main/{java/org/apache/spark/network/netty/FileClientChannelInitializer.java => scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala} (57%) rename core/src/main/{java/org/apache/spark/network/netty/FileClientHandler.java => scala/org/apache/spark/network/netty/FileClientHandler.scala} (51%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServer.scala rename core/src/main/{java/org/apache/spark/network/netty/FileServerChannelInitializer.java => scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala} (54%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala rename core/src/main/{java/org/apache/spark/network/netty/PathResolver.java => scala/org/apache/spark/network/netty/PathResolver.scala} (80%) mode change 100755 => 100644 diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java deleted file mode 100644 index 0d31894d6ec7a..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.util.concurrent.TimeUnit; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -class FileClient { - - private static final Logger LOG = LoggerFactory.getLogger(FileClient.class.getName()); - - private final FileClientHandler handler; - private Channel channel = null; - private Bootstrap bootstrap = null; - private EventLoopGroup group = null; - private final int connectTimeout; - private final int sendTimeout = 60; // 1 min - - FileClient(FileClientHandler handler, int connectTimeout) { - this.handler = handler; - this.connectTimeout = connectTimeout; - } - - public void init() { - group = new OioEventLoopGroup(); - bootstrap = new Bootstrap(); - bootstrap.group(group) - .channel(OioSocketChannel.class) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) - .handler(new FileClientChannelInitializer(handler)); - } - - public void connect(String host, int port) { - try { - // Start the connection attempt. - channel = bootstrap.connect(host, port).sync().channel(); - // ChannelFuture cf = channel.closeFuture(); - //cf.addListener(new ChannelCloseListener(this)); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted while trying to connect", e); - close(); - } - } - - public void waitForClose() { - try { - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted", e); - } - } - - public void sendRequest(String file) { - //assert(file == null); - //assert(channel == null); - try { - // Should be able to send the message to network link channel. - boolean bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS); - if (!bSent) { - throw new RuntimeException("Failed to send"); - } - } catch (InterruptedException e) { - LOG.error("Error", e); - } - } - - public void close() { - if (group != null) { - group.shutdownGracefully(); - group = null; - bootstrap = null; - } - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java deleted file mode 100644 index c93425e2787dc..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.net.InetSocketAddress; - -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioServerSocketChannel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer { - - private static final Logger LOG = LoggerFactory.getLogger(FileServer.class.getName()); - - private EventLoopGroup bossGroup = null; - private EventLoopGroup workerGroup = null; - private ChannelFuture channelFuture = null; - private int port = 0; - - FileServer(PathResolver pResolver, int port) { - InetSocketAddress addr = new InetSocketAddress(port); - - // Configure the server. - bossGroup = new OioEventLoopGroup(); - workerGroup = new OioEventLoopGroup(); - - ServerBootstrap bootstrap = new ServerBootstrap(); - bootstrap.group(bossGroup, workerGroup) - .channel(OioServerSocketChannel.class) - .option(ChannelOption.SO_BACKLOG, 100) - .option(ChannelOption.SO_RCVBUF, 1500) - .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channelFuture = bootstrap.bind(addr); - try { - // Get the address we bound to. - InetSocketAddress boundAddress = - ((InetSocketAddress) channelFuture.sync().channel().localAddress()); - this.port = boundAddress.getPort(); - } catch (InterruptedException ie) { - this.port = 0; - } - } - - /** - * Start the file server asynchronously in a new thread. - */ - public void start() { - Thread blockingThread = new Thread() { - @Override - public void run() { - try { - channelFuture.channel().closeFuture().sync(); - LOG.info("FileServer exiting"); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } - // NOTE: bootstrap is shutdown in stop() - } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } - - public int getPort() { - return port; - } - - public void stop() { - // Close the bound channel. - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly(); - channelFuture = null; - } - - // Shutdown event groups - if (bossGroup != null) { - bossGroup.shutdownGracefully(); - bossGroup = null; - } - - if (workerGroup != null) { - workerGroup.shutdownGracefully(); - workerGroup = null; - } - // TODO: Shutdown all accepted channels as well ? - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java deleted file mode 100644 index c0133e19c7f79..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.io.File; -import java.io.FileInputStream; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.channel.DefaultFileRegion; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; - -class FileServerHandler extends SimpleChannelInboundHandler { - - private static final Logger LOG = LoggerFactory.getLogger(FileServerHandler.class.getName()); - - private final PathResolver pResolver; - - FileServerHandler(PathResolver pResolver){ - this.pResolver = pResolver; - } - - @Override - public void channelRead0(ChannelHandlerContext ctx, String blockIdString) { - BlockId blockId = BlockId.apply(blockIdString); - FileSegment fileSegment = pResolver.getBlockLocation(blockId); - // if getBlockLocation returns null, close the channel - if (fileSegment == null) { - //ctx.close(); - return; - } - File file = fileSegment.file(); - if (file.exists()) { - if (!file.isFile()) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - long length = fileSegment.length(); - if (length > Integer.MAX_VALUE || length <= 0) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - int len = (int) length; - ctx.write((new FileHeader(len, blockId)).buffer()); - try { - ctx.write(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), fileSegment.offset(), fileSegment.length())); - } catch (Exception e) { - LOG.error("Exception: ", e); - } - } else { - ctx.write(new FileHeader(0, blockId).buffer()); - } - ctx.flush(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - LOG.error("Exception: ", cause); - ctx.close(); - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala new file mode 100644 index 0000000000000..c6d35f73db545 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.Bootstrap +import io.netty.channel.{Channel, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioSocketChannel + +import org.apache.spark.Logging + +class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging { + + private var channel: Channel = _ + private var bootstrap: Bootstrap = _ + private var group: EventLoopGroup = _ + private val sendTimeout = 60 + + def init(): Unit = { + group = new OioEventLoopGroup + bootstrap = new Bootstrap + bootstrap.group(group) + .channel(classOf[OioSocketChannel]) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout)) + .handler(new FileClientChannelInitializer(handler)) + } + + def connect(host: String, port: Int) { + try { + channel = bootstrap.connect(host, port).sync().channel() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted while trying to connect", e) + close() + } + } + + def waitForClose(): Unit = { + try { + channel.closeFuture.sync() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted", e) + } + } + + def sendRequest(file: String): Unit = { + try { + val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS) + if (!bSent) { + throw new RuntimeException("Failed to send") + } + } catch { + case e: InterruptedException => + logError("Error", e) + } + } + + def close(): Unit = { + if (group != null) { + group.shutdownGracefully() + group = null + bootstrap = null + } + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala similarity index 57% rename from core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java rename to core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala index 264cf97d0209f..f4261c13f70a8 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala @@ -15,25 +15,17 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.string.StringEncoder; +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.string.StringEncoder -class FileClientChannelInitializer extends ChannelInitializer { - private final FileClientHandler fhandler; +class FileClientChannelInitializer(handler: FileClientHandler) + extends ChannelInitializer[SocketChannel] { - FileClientChannelInitializer(FileClientHandler handler) { - fhandler = handler; - } - - @Override - public void initChannel(SocketChannel channel) { - // file no more than 2G - channel.pipeline() - .addLast("encoder", new StringEncoder()) - .addLast("handler", fhandler); + def initChannel(channel: SocketChannel) { + channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler) } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala similarity index 51% rename from core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java rename to core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala index 63d3d927255f9..017302ec7d33d 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala @@ -15,41 +15,36 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.buffer.ByteBuf +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockId -abstract class FileClientHandler extends SimpleChannelInboundHandler { - private FileHeader currentHeader = null; +abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] { - private volatile boolean handlerCalled = false; + private var currentHeader: FileHeader = null - public boolean isComplete() { - return handlerCalled; - } + @volatile + private var handlerCalled: Boolean = false + + def isComplete: Boolean = handlerCalled + + def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) - public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(BlockId blockId); + def handleError(blockId: BlockId) - @Override - public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) { - // get header - if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { + if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE)) } - // get file - if(in.readableBytes() >= currentHeader.fileLen()) { - handle(ctx, in, currentHeader); - handlerCalled = true; - currentHeader = null; - ctx.close(); + if (in.readableBytes >= currentHeader.fileLen) { + handle(ctx, in, currentHeader) + handlerCalled = true + currentHeader = null + ctx.close() } } - } - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 136c1912045aa..607e560ff277f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -26,7 +26,7 @@ private[spark] class FileHeader ( val fileLen: Int, val blockId: BlockId) extends Logging { - lazy val buffer = { + lazy val buffer: ByteBuf = { val buf = Unpooled.buffer() buf.capacity(FileHeader.HEADER_SIZE) buf.writeInt(fileLen) @@ -62,11 +62,10 @@ private[spark] object FileHeader { new FileHeader(length, blockId) } - def main (args:Array[String]) { + def main(args:Array[String]) { val header = new FileHeader(25, TestBlockId("my_block")) val buf = header.buffer val newHeader = FileHeader.create(buf) System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) } } - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala new file mode 100644 index 0000000000000..dff77950659af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.net.InetSocketAddress + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioServerSocketChannel + +import org.apache.spark.Logging + +/** + * Server that accept the path of a file an echo back its content. + */ +class FileServer(pResolver: PathResolver, private var port: Int) extends Logging { + + private val addr: InetSocketAddress = new InetSocketAddress(port) + private var bossGroup: EventLoopGroup = new OioEventLoopGroup + private var workerGroup: EventLoopGroup = new OioEventLoopGroup + + private var channelFuture: ChannelFuture = { + val bootstrap = new ServerBootstrap + bootstrap.group(bossGroup, workerGroup) + .channel(classOf[OioServerSocketChannel]) + .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100)) + .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500)) + .childHandler(new FileServerChannelInitializer(pResolver)) + bootstrap.bind(addr) + } + + try { + val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress] + port = boundAddress.getPort + } catch { + case ie: InterruptedException => + port = 0 + } + + /** Start the file server asynchronously in a new thread. */ + def start(): Unit = { + val blockingThread: Thread = new Thread { + override def run(): Unit = { + try { + channelFuture.channel.closeFuture.sync + logInfo("FileServer exiting") + } catch { + case e: InterruptedException => + logError("File server start got interrupted", e) + } + // NOTE: bootstrap is shutdown in stop() + } + } + blockingThread.setDaemon(true) + blockingThread.start() + } + + def getPort: Int = port + + def stop(): Unit = { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly() + channelFuture = null + } + if (bossGroup != null) { + bossGroup.shutdownGracefully() + bossGroup = null + } + if (workerGroup != null) { + workerGroup.shutdownGracefully() + workerGroup = null + } + } +} + diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala similarity index 54% rename from core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java rename to core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala index 46efec8f8d963..aaa2f913d0269 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala @@ -15,27 +15,20 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.DelimiterBasedFrameDecoder; -import io.netty.handler.codec.Delimiters; -import io.netty.handler.codec.string.StringDecoder; +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters} +import io.netty.handler.codec.string.StringDecoder -class FileServerChannelInitializer extends ChannelInitializer { +class FileServerChannelInitializer(pResolver: PathResolver) + extends ChannelInitializer[SocketChannel] { - private final PathResolver pResolver; - - FileServerChannelInitializer(PathResolver pResolver) { - this.pResolver = pResolver; - } - - @Override - public void initChannel(SocketChannel channel) { - channel.pipeline() - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("stringDecoder", new StringDecoder()) - .addLast("handler", new FileServerHandler(pResolver)); + override def initChannel(channel: SocketChannel): Unit = { + channel.pipeline + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*)) + .addLast("stringDecoder", new StringDecoder) + .addLast("handler", new FileServerHandler(pResolver)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala new file mode 100644 index 0000000000000..96f60b2883ad9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.FileInputStream + +import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.storage.{BlockId, FileSegment} + + +class FileServerHandler(pResolver: PathResolver) + extends SimpleChannelInboundHandler[String] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = { + val blockId: BlockId = BlockId(blockIdString) + val fileSegment: FileSegment = pResolver.getBlockLocation(blockId) + if (fileSegment == null) { + return + } + val file = fileSegment.file + if (file.exists) { + if (!file.isFile) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + val length: Long = fileSegment.length + if (length > Integer.MAX_VALUE || length <= 0) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + ctx.write(new FileHeader(length.toInt, blockId).buffer) + try { + val channel = new FileInputStream(file).getChannel + ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length)) + } catch { + case e: Exception => + logError("Exception: ", e) + } + } else { + ctx.write(new FileHeader(0, blockId).buffer) + } + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError("Exception: ", cause) + ctx.close() + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala old mode 100755 new mode 100644 similarity index 80% rename from core/src/main/java/org/apache/spark/network/netty/PathResolver.java rename to core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala index 7ad8d03efbadc..0d7695072a7b1 --- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java +++ b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; +import org.apache.spark.storage.{BlockId, FileSegment} -public interface PathResolver { +trait PathResolver { /** Get the file segment in which the given block resides. */ - FileSegment getBlockLocation(BlockId blockId); + def getBlockLocation(blockId: BlockId): FileSegment } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 7ef7aecc6a9fb..95958e30f7eeb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -32,7 +32,7 @@ private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) ext server.stop() } - def port: Int = server.getPort() + def port: Int = server.getPort } From 3f23d2a38c3b6559902bc2ab6975ff6b0bec875e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Aug 2014 19:01:33 -0700 Subject: [PATCH 145/231] [SPARK-2468] Netty based block server / client module This is a rewrite of the original Netty module that was added about 1.5 years ago. The old code was turned off by default and didn't really work because it lacked a frame decoder (only worked with very very small blocks). For this pull request, I tried to make the changes non-instrusive to the rest of Spark. I only added an init and shutdown to BlockManager/DiskBlockManager, and a bunch of comments to help me understand the existing code base. Compared with the old Netty module, this one features: - It appears to work :) - SPARK-2941: option to specicy nio vs oio vs epoll for channel/transport. By default nio is used. (Not using Epoll yet because I have found some bugs with its implementation) - SPARK-2943: options to specify send buf and receive buf for users who want to do hyper tuning - SPARK-2942: io errors are reported from server to client (the protocol uses negative length to indicate error) - SPARK-2940: fetching multiple blocks in a single request to reduce syscalls - SPARK-2959: clients share a single thread pool - SPARK-2990: use PooledByteBufAllocator to reduce GC (basically a Netty managed pool of buffers with jmalloc) - SPARK-2625: added fetchWaitTime metric and fixed thread-safety issue in metrics update. - SPARK-2367: bump Netty version to 4.0.21.Final to address an Epoll bug (https://groups.google.com/forum/#!topic/netty/O7m-HxCJpCA) Compared with the existing communication manager, this one features: - IMO it is substantially easier to understand - zero-copy send for the server for on-disk blocks - one-copy receive (due to a frame decoder) - don't quote me on this, but I think a lot less sys calls - SPARK-2990: use PooledByteBufAllocator to reduce GC (basically a Netty managed pool of buffers with jmalloc) - SPARK-2941: option to specicy nio vs oio vs epoll for channel/transport. By default nio is used. (Not using Epoll yet because I have found some bugs with its implementation) - SPARK-2943: options to specify send buf and receive buf for users who want to do hyper tuning TODOs before it can fully replace the existing ConnectionManager, if that ever happens (most of them should probably be done in separate PRs since this needs to be turned on explicitly) - [x] Basic test cases - [ ] More unit/integration tests for failures - [ ] Performance analysis - [ ] Support client connection reuse so we don't need to keep opening new connections (not sure how useful this would be) - [ ] Support putting blocks in addition to fetching blocks (i.e. two way transfer) - [x] Support serving non-disk blocks - [ ] Support SASL authentication For a more comprehensive list, see https://issues.apache.org/jira/browse/SPARK-2468 Thanks to @coderplay for peer coding with me on a Sunday. Author: Reynold Xin Closes #1907 from rxin/netty and squashes the following commits: f921421 [Reynold Xin] Upgrade Netty to 4.0.22.Final to fix another Epoll bug. 4b174ca [Reynold Xin] Shivaram's code review comment. 4a3dfe7 [Reynold Xin] Switched to nio for default (instead of epoll on Linux). 56bfb9d [Reynold Xin] Bump Netty version to 4.0.21.Final for some bug fixes. b443a4b [Reynold Xin] Added debug message to help debug Jenkins failures. 57fc4d7 [Reynold Xin] Added test cases for BlockHeaderEncoder and BlockFetchingClientHandlerSuite. 22623e9 [Reynold Xin] Added exception handling and test case for BlockServerHandler and BlockFetchingClientHandler. 6550dd7 [Reynold Xin] Fixed block mgr init bug. 60c2edf [Reynold Xin] Beefed up server/client integration tests. 38d88d5 [Reynold Xin] Added missing test files. 6ce3f3c [Reynold Xin] Added some basic test cases. 47f7ce0 [Reynold Xin] Created server and client packages and moved files there. b16f412 [Reynold Xin] Added commit count. f13022d [Reynold Xin] Remove unused clone() in BlockFetcherIterator. c57d68c [Reynold Xin] Added back missing files. 842dfa7 [Reynold Xin] Made everything work with proper reference counting. 3fae001 [Reynold Xin] Connected the new netty network module with rest of Spark. 1a8f6d4 [Reynold Xin] Completed protocol documentation. 2951478 [Reynold Xin] New Netty implementation. cc7843d [Reynold Xin] Basic skeleton. (cherry picked from commit 3a8b68b7353fea50245686903b308fa9eb52cb51) Signed-off-by: Reynold Xin --- .../spark/network/netty/FileClient.scala | 85 - .../network/netty/FileClientHandler.scala | 50 - .../spark/network/netty/FileHeader.scala | 71 - .../spark/network/netty/FileServer.scala | 91 -- .../network/netty/FileServerHandler.scala | 68 - .../spark/network/netty/NettyConfig.scala | 59 + .../spark/network/netty/ShuffleCopier.scala | 118 -- .../spark/network/netty/ShuffleSender.scala | 71 - .../netty/client/BlockFetchingClient.scala | 135 ++ .../client/BlockFetchingClientFactory.scala | 99 ++ .../client/BlockFetchingClientHandler.scala | 63 + .../netty/client/LazyInitIterator.scala | 44 + .../netty/client/ReferenceCountedBuffer.scala | 47 + .../network/netty/server/BlockHeader.scala | 32 + .../netty/server/BlockHeaderEncoder.scala | 47 + .../network/netty/server/BlockServer.scala | 162 ++ .../BlockServerChannelInitializer.scala} | 22 +- .../netty/server/BlockServerHandler.scala | 140 ++ .../BlockDataProvider.scala} | 21 +- .../spark/storage/BlockFetcherIterator.scala | 138 +- .../apache/spark/storage/BlockManager.scala | 49 +- .../storage/BlockNotFoundException.scala | 21 + .../spark/storage/DiskBlockManager.scala | 13 +- core/src/test/resources/netty-test-file.txt | 1379 +++++++++++++++++ .../netty/ServerClientIntegrationSuite.scala | 158 ++ .../BlockFetchingClientHandlerSuite.scala | 87 ++ .../server/BlockHeaderEncoderSuite.scala | 64 + .../server/BlockServerHandlerSuite.scala | 101 ++ pom.xml | 2 +- 29 files changed, 2770 insertions(+), 667 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala rename core/src/main/scala/org/apache/spark/network/netty/{FileServerChannelInitializer.scala => server/BlockServerChannelInitializer.scala} (58%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala rename core/src/main/scala/org/apache/spark/{network/netty/FileClientChannelInitializer.scala => storage/BlockDataProvider.scala} (65%) create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala create mode 100644 core/src/test/resources/netty-test-file.txt create mode 100644 core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala deleted file mode 100644 index c6d35f73db545..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.TimeUnit - -import io.netty.bootstrap.Bootstrap -import io.netty.channel.{Channel, ChannelOption, EventLoopGroup} -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.oio.OioSocketChannel - -import org.apache.spark.Logging - -class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging { - - private var channel: Channel = _ - private var bootstrap: Bootstrap = _ - private var group: EventLoopGroup = _ - private val sendTimeout = 60 - - def init(): Unit = { - group = new OioEventLoopGroup - bootstrap = new Bootstrap - bootstrap.group(group) - .channel(classOf[OioSocketChannel]) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout)) - .handler(new FileClientChannelInitializer(handler)) - } - - def connect(host: String, port: Int) { - try { - channel = bootstrap.connect(host, port).sync().channel() - } catch { - case e: InterruptedException => - logWarning("FileClient interrupted while trying to connect", e) - close() - } - } - - def waitForClose(): Unit = { - try { - channel.closeFuture.sync() - } catch { - case e: InterruptedException => - logWarning("FileClient interrupted", e) - } - } - - def sendRequest(file: String): Unit = { - try { - val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS) - if (!bSent) { - throw new RuntimeException("Failed to send") - } - } catch { - case e: InterruptedException => - logError("Error", e) - } - } - - def close(): Unit = { - if (group != null) { - group.shutdownGracefully() - group = null - bootstrap = null - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala deleted file mode 100644 index 017302ec7d33d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.storage.BlockId - - -abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] { - - private var currentHeader: FileHeader = null - - @volatile - private var handlerCalled: Boolean = false - - def isComplete: Boolean = handlerCalled - - def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) - - def handleError(blockId: BlockId) - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE)) - } - if (in.readableBytes >= currentHeader.fileLen) { - handle(ctx, in, currentHeader) - handlerCalled = true - currentHeader = null - ctx.close() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala deleted file mode 100644 index 607e560ff277f..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.buffer._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{BlockId, TestBlockId} - -private[spark] class FileHeader ( - val fileLen: Int, - val blockId: BlockId) extends Logging { - - lazy val buffer: ByteBuf = { - val buf = Unpooled.buffer() - buf.capacity(FileHeader.HEADER_SIZE) - buf.writeInt(fileLen) - buf.writeInt(blockId.name.length) - blockId.name.foreach((x: Char) => buf.writeByte(x)) - // padding the rest of header - if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { - buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) - } else { - throw new Exception("too long header " + buf.readableBytes) - logInfo("too long header") - } - buf - } - -} - -private[spark] object FileHeader { - - val HEADER_SIZE = 40 - - def getFileLenOffset = 0 - def getFileLenSize = Integer.SIZE/8 - - def create(buf: ByteBuf): FileHeader = { - val length = buf.readInt - val idLength = buf.readInt - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buf.readByte().asInstanceOf[Char] - } - val blockId = BlockId(idBuilder.toString()) - new FileHeader(length, blockId) - } - - def main(args:Array[String]) { - val header = new FileHeader(25, TestBlockId("my_block")) - val buf = header.buffer - val newHeader = FileHeader.create(buf) - System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala deleted file mode 100644 index dff77950659af..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup} -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.oio.OioServerSocketChannel - -import org.apache.spark.Logging - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer(pResolver: PathResolver, private var port: Int) extends Logging { - - private val addr: InetSocketAddress = new InetSocketAddress(port) - private var bossGroup: EventLoopGroup = new OioEventLoopGroup - private var workerGroup: EventLoopGroup = new OioEventLoopGroup - - private var channelFuture: ChannelFuture = { - val bootstrap = new ServerBootstrap - bootstrap.group(bossGroup, workerGroup) - .channel(classOf[OioServerSocketChannel]) - .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100)) - .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500)) - .childHandler(new FileServerChannelInitializer(pResolver)) - bootstrap.bind(addr) - } - - try { - val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress] - port = boundAddress.getPort - } catch { - case ie: InterruptedException => - port = 0 - } - - /** Start the file server asynchronously in a new thread. */ - def start(): Unit = { - val blockingThread: Thread = new Thread { - override def run(): Unit = { - try { - channelFuture.channel.closeFuture.sync - logInfo("FileServer exiting") - } catch { - case e: InterruptedException => - logError("File server start got interrupted", e) - } - // NOTE: bootstrap is shutdown in stop() - } - } - blockingThread.setDaemon(true) - blockingThread.start() - } - - def getPort: Int = port - - def stop(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bossGroup != null) { - bossGroup.shutdownGracefully() - bossGroup = null - } - if (workerGroup != null) { - workerGroup.shutdownGracefully() - workerGroup = null - } - } -} - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala deleted file mode 100644 index 96f60b2883ad9..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.FileInputStream - -import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging -import org.apache.spark.storage.{BlockId, FileSegment} - - -class FileServerHandler(pResolver: PathResolver) - extends SimpleChannelInboundHandler[String] with Logging { - - override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = { - val blockId: BlockId = BlockId(blockIdString) - val fileSegment: FileSegment = pResolver.getBlockLocation(blockId) - if (fileSegment == null) { - return - } - val file = fileSegment.file - if (file.exists) { - if (!file.isFile) { - ctx.write(new FileHeader(0, blockId).buffer) - ctx.flush() - return - } - val length: Long = fileSegment.length - if (length > Integer.MAX_VALUE || length <= 0) { - ctx.write(new FileHeader(0, blockId).buffer) - ctx.flush() - return - } - ctx.write(new FileHeader(length.toInt, blockId).buffer) - try { - val channel = new FileInputStream(file).getChannel - ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length)) - } catch { - case e: Exception => - logError("Exception: ", e) - } - } else { - ctx.write(new FileHeader(0, blockId).buffer) - } - ctx.flush() - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError("Exception: ", cause) - ctx.close() - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala new file mode 100644 index 0000000000000..b5870152c5a64 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import org.apache.spark.SparkConf + +/** + * A central location that tracks all the settings we exposed to users. + */ +private[spark] +class NettyConfig(conf: SparkConf) { + + /** Port the server listens on. Default to a random port. */ + private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) + + /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ + private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase + + /** Connect timeout in secs. Default 60 secs. */ + private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 + + /** + * Percentage of the desired amount of time spent for I/O in the child event loops. + * Only applicable in nio and epoll. + */ + private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) + + /** Requested maximum length of the queue of incoming connections. */ + private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + + /** + * Receive buffer size (SO_RCVBUF). + * Note: the optimal size for receive buffer and send buffer should be + * latency * network_bandwidth. + * Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + private[netty] val receiveBuf: Option[Int] = + conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + + /** Send buffer size (SO_SNDBUF). */ + private[netty] val sendBuf: Option[Int] = + conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala deleted file mode 100644 index e7b2855e1ec91..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.Executors - -import scala.collection.JavaConverters._ - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.util.CharsetUtil - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.storage.BlockId - -private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { - - def getBlock(host: String, port: Int, blockId: BlockId, - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - - val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000) - val fc = new FileClient(handler, connectTimeout) - - try { - fc.init() - fc.connect(host, port) - fc.sendRequest(blockId.name) - fc.waitForClose() - fc.close() - } catch { - // Handle any socket-related exceptions in FileClient - case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) - handler.handleError(blockId) - } - } - } - - def getBlock(cmId: ConnectionManagerId, blockId: BlockId, - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) - } - - def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(BlockId, Long)], - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - - for ((blockId, size) <- blocks) { - getBlock(cmId, blockId, resultCollectCallback) - } - } -} - - -private[spark] object ShuffleCopier extends Logging { - - private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit) - extends FileClientHandler with Logging { - - override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) - } - - override def handleError(blockId: BlockId) { - if (!isComplete) { - resultCollectCallBack(blockId, -1, null) - } - } - } - - def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) { - if (size != -1) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: ShuffleCopier ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val blockId = BlockId(args(2)) - val threads = if (args.length > 3) args(3).toInt else 10 - - val copiers = Executors.newFixedThreadPool(80) - val tasks = (for (i <- Range(0, threads)) yield { - Executors.callable(new Runnable() { - def run() { - val copier = new ShuffleCopier(new SparkConf) - copier.getBlock(host, port, blockId, echoResultCollectCallBack) - } - }) - }).asJava - copiers.invokeAll(tasks) - copiers.shutdown() - System.exit(0) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala deleted file mode 100644 index 95958e30f7eeb..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.File - -import org.apache.spark.Logging -import org.apache.spark.util.Utils -import org.apache.spark.storage.{BlockId, FileSegment} - -private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { - - val server = new FileServer(pResolver, portIn) - server.start() - - def stop() { - server.stop() - } - - def port: Int = server.getPort -} - - -/** - * An application for testing the shuffle sender as a standalone program. - */ -private[spark] object ShuffleSender { - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ShuffleSender ") - System.exit(1) - } - - val port = args(0).toInt - val subDirsPerLocalDir = args(1).toInt - val localDirs = args.drop(2).map(new File(_)) - - val pResovler = new PathResolver { - override def getBlockLocation(blockId: BlockId): FileSegment = { - if (!blockId.isShuffle) { - throw new Exception("Block " + blockId + " is not a shuffle block") - } - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(blockId) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId.name) - new FileSegment(file, 0, file.length()) - } - } - val sender = new ShuffleSender(port, pResovler) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala new file mode 100644 index 0000000000000..9fed11b75c342 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.util.concurrent.TimeoutException + +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel.socket.SocketChannel +import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.string.StringEncoder +import io.netty.util.CharsetUtil + +import org.apache.spark.Logging + +/** + * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. + * Use [[BlockFetchingClientFactory]] to instantiate this client. + * + * The constructor blocks until a connection is successfully established. + * + * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. + * + * Concurrency: [[BlockFetchingClient]] is not thread safe and should not be shared. + */ +@throws[TimeoutException] +private[spark] +class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) + extends Logging { + + val handler = new BlockFetchingClientHandler + + /** Netty Bootstrap for creating the TCP connection. */ + private val bootstrap: Bootstrap = { + val b = new Bootstrap + b.group(factory.workerGroup) + .channel(factory.socketChannelClass) + // Use pooled buffers to reduce temporary buffer allocation + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) + + b.handler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) + // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 + .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) + .addLast("handler", handler) + } + }) + b + } + + /** Netty ChannelFuture for the connection. */ + private val cf: ChannelFuture = bootstrap.connect(hostname, port) + if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { + throw new TimeoutException( + s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") + } + + /** + * Ask the remote server for a sequence of blocks, and execute the callback. + * + * Note that this is asynchronous and returns immediately. Upstream caller should throttle the + * rate of fetching; otherwise we could run out of memory. + * + * @param blockIds sequence of block ids to fetch. + * @param blockFetchSuccessCallback callback function when a block is successfully fetched. + * First argument is the block id, and second argument is the + * raw data in a ByteBuffer. + * @param blockFetchFailureCallback callback function when we failed to fetch any of the blocks. + * First argument is the block id, and second argument is the + * error message. + */ + def fetchBlocks( + blockIds: Seq[String], + blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit, + blockFetchFailureCallback: (String, String) => Unit): Unit = { + // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. + // It's also best to limit the number of "flush" calls since it requires system calls. + // Let's concatenate the string and then call writeAndFlush once. + // This is also why this implementation might be more efficient than multiple, separate + // fetch block calls. + var startTime: Long = 0 + logTrace { + startTime = System.nanoTime + s"Sending request $blockIds to $hostname:$port" + } + + // TODO: This is not the most elegant way to handle this ... + handler.blockFetchSuccessCallback = blockFetchSuccessCallback + handler.blockFetchFailureCallback = blockFetchFailureCallback + + val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") + writeFuture.addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace { + val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 + s"Sending request $blockIds to $hostname:$port took $timeTaken ms" + } + } else { + // Fail all blocks. + logError(s"Failed to send request $blockIds to $hostname:$port", future.cause) + blockIds.foreach(blockFetchFailureCallback(_, future.cause.getMessage)) + } + } + }) + } + + def waitForClose(): Unit = { + cf.channel().closeFuture().sync() + } + + def close(): Unit = cf.channel().close() +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala new file mode 100644 index 0000000000000..2b28402c52b49 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.channel.socket.oio.OioSocketChannel +import io.netty.channel.{EventLoopGroup, Channel} + +import org.apache.spark.SparkConf +import org.apache.spark.network.netty.NettyConfig +import org.apache.spark.util.Utils + +/** + * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses + * the worker thread pool for Netty. + * + * Concurrency: createClient is safe to be called from multiple threads concurrently. + */ +private[spark] +class BlockFetchingClientFactory(val conf: NettyConfig) { + + def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) + + /** A thread factory so the threads are named (for debugging). */ + val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + + /** The following two are instantiated by the [[init]] method, depending ioMode. */ + var socketChannelClass: Class[_ <: Channel] = _ + var workerGroup: EventLoopGroup = _ + + init() + + /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ + private def init(): Unit = { + def initOio(): Unit = { + socketChannelClass = classOf[OioSocketChannel] + workerGroup = new OioEventLoopGroup(0, threadFactory) + } + def initNio(): Unit = { + socketChannelClass = classOf[NioSocketChannel] + workerGroup = new NioEventLoopGroup(0, threadFactory) + } + def initEpoll(): Unit = { + socketChannelClass = classOf[EpollSocketChannel] + workerGroup = new EpollEventLoopGroup(0, threadFactory) + } + + conf.ioMode match { + case "nio" => initNio() + case "oio" => initOio() + case "epoll" => initEpoll() + case "auto" => + // For auto mode, first try epoll (only available on Linux), then nio. + try { + initEpoll() + } catch { + // TODO: Should we log the throwable? But that always happen on non-Linux systems. + // Perhaps the right thing to do is to check whether the system is Linux, and then only + // call initEpoll on Linux. + case e: Throwable => initNio() + } + } + } + + /** + * Create a new BlockFetchingClient connecting to the given remote host / port. + * + * This blocks until a connection is successfully established. + * + * Concurrency: This method is safe to call from multiple threads. + */ + def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { + new BlockFetchingClient(this, remoteHost, remotePort) + } + + def stop(): Unit = { + if (workerGroup != null) { + workerGroup.shutdownGracefully() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala new file mode 100644 index 0000000000000..a1dbf6102c080 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import io.netty.buffer.ByteBuf +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging + + +/** + * Handler that processes server responses. It uses the protocol documented in + * [[org.apache.spark.network.netty.server.BlockServer]]. + */ +private[client] +class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { + + var blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit = _ + var blockFetchFailureCallback: (String, String) => Unit = _ + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { + val totalLen = in.readInt() + val blockIdLen = in.readInt() + val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) + in.readBytes(blockIdBytes) + val blockId = new String(blockIdBytes) + val blockSize = totalLen - math.abs(blockIdLen) - 4 + + def server = ctx.channel.remoteAddress.toString + + // blockIdLen is negative when it is an error message. + if (blockIdLen < 0) { + val errorMessageBytes = new Array[Byte](blockSize) + in.readBytes(errorMessageBytes) + val errorMsg = new String(errorMessageBytes) + logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") + blockFetchFailureCallback(blockId, errorMsg) + } else { + logTrace(s"Received block $blockId ($blockSize B) from $server") + blockFetchSuccessCallback(blockId, new ReferenceCountedBuffer(in)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala new file mode 100644 index 0000000000000..9740ee64d1f2d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +/** + * A simple iterator that lazily initializes the underlying iterator. + * + * The use case is that sometimes we might have many iterators open at the same time, and each of + * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). + * This could lead to too many buffers open. If this iterator is used, we lazily initialize those + * buffers. + */ +private[spark] +class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { + + lazy val proxy = createIterator + + override def hasNext: Boolean = { + val gotNext = proxy.hasNext + if (!gotNext) { + close() + } + gotNext + } + + override def next(): Any = proxy.next() + + def close(): Unit = Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala new file mode 100644 index 0000000000000..ea1abf5eccc26 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.io.InputStream +import java.nio.ByteBuffer + +import io.netty.buffer.{ByteBuf, ByteBufInputStream} + + +/** + * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. + * This is a Scala value class. + * + * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of + * reference by the retain method and release method. + */ +private[spark] +class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { + + /** Return the nio ByteBuffer view of the underlying buffer. */ + def byteBuffer(): ByteBuffer = underlying.nioBuffer + + /** Creates a new input stream that starts from the current position of the buffer. */ + def inputStream(): InputStream = new ByteBufInputStream(underlying) + + /** Increment the reference counter by one. */ + def retain(): Unit = underlying.retain() + + /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ + def release(): Unit = underlying.release() +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala new file mode 100644 index 0000000000000..162e9cc6828d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +/** + * Header describing a block. This is used only in the server pipeline. + * + * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. + * + * @param blockSize length of the block content, excluding the length itself. + * If positive, this is the header for a block (not part of the header). + * If negative, this is the header and content for an error message. + * @param blockId block id + * @param error some error message from reading the block + */ +private[server] +class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala new file mode 100644 index 0000000000000..8e4dda4ef8595 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.MessageToByteEncoder + +/** + * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. + */ +private[server] +class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { + override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { + // message = message length (4 bytes) + block id length (4 bytes) + block id + block data + // message length = block id length (4 bytes) + size of block id + size of block data + val blockIdBytes = msg.blockId.getBytes + msg.error match { + case Some(errorMsg) => + val errorBytes = errorMsg.getBytes + out.writeInt(4 + blockIdBytes.length + errorBytes.size) + out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors + out.writeBytes(blockIdBytes) // next is blockId itself + out.writeBytes(errorBytes) // error message + case None => + out.writeInt(4 + blockIdBytes.length + msg.blockSize) + out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length + out.writeBytes(blockIdBytes) // next is blockId itself + // msg of size blockSize will be written by ServerHandler + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala new file mode 100644 index 0000000000000..7b2f9a8d4dfd0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.net.InetSocketAddress + +import io.netty.bootstrap.ServerBootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} +import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.socket.oio.OioServerSocketChannel +import io.netty.handler.codec.LineBasedFrameDecoder +import io.netty.handler.codec.string.StringDecoder +import io.netty.util.CharsetUtil + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.netty.NettyConfig +import org.apache.spark.storage.BlockDataProvider +import org.apache.spark.util.Utils + + +/** + * Server for serving Spark data blocks. + * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. + * + * Protocol for requesting blocks (client to server): + * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" + * + * Protocol for sending blocks (server to client): + * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. + * + * frame-length should not include the length of itself. + * If block-id-length is negative, then this is an error message rather than block-data. The real + * length is the absolute value of the frame-length. + * + */ +private[spark] +class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { + + def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { + this(new NettyConfig(sparkConf), dataProvider) + } + + def port: Int = _port + + def hostName: String = _hostName + + private var _port: Int = conf.serverPort + private var _hostName: String = "" + private var bootstrap: ServerBootstrap = _ + private var channelFuture: ChannelFuture = _ + + init() + + /** Initialize the server. */ + private def init(): Unit = { + bootstrap = new ServerBootstrap + val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") + val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") + + // Use only one thread to accept connections, and 2 * num_cores for worker. + def initNio(): Unit = { + val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) + val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) + workerGroup.setIoRatio(conf.ioRatio) + bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) + } + def initOio(): Unit = { + val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) + val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) + bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) + } + def initEpoll(): Unit = { + val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) + val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) + workerGroup.setIoRatio(conf.ioRatio) + bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) + } + + conf.ioMode match { + case "nio" => initNio() + case "oio" => initOio() + case "epoll" => initEpoll() + case "auto" => + // For auto mode, first try epoll (only available on Linux), then nio. + try { + initEpoll() + } catch { + // TODO: Should we log the throwable? But that always happen on non-Linux systems. + // Perhaps the right thing to do is to check whether the system is Linux, and then only + // call initEpoll on Linux. + case e: Throwable => initNio() + } + } + + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + + // Various (advanced) user-configured settings. + conf.backLog.foreach { backLog => + bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) + } + conf.receiveBuf.foreach { receiveBuf => + bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) + } + conf.sendBuf.foreach { sendBuf => + bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) + } + + bootstrap.childHandler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 + .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) + .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + .addLast("handler", new BlockServerHandler(dataProvider)) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(_port)) + channelFuture.sync() + + val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] + _port = addr.getPort + _hostName = addr.getHostName + } + + /** Shutdown the server. */ + def stop(): Unit = { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly() + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala similarity index 58% rename from core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala rename to core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala index aaa2f913d0269..cc70bd0c5c477 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala @@ -15,20 +15,26 @@ * limitations under the License. */ -package org.apache.spark.network.netty +package org.apache.spark.network.netty.server import io.netty.channel.ChannelInitializer import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters} +import io.netty.handler.codec.LineBasedFrameDecoder import io.netty.handler.codec.string.StringDecoder +import io.netty.util.CharsetUtil +import org.apache.spark.storage.BlockDataProvider -class FileServerChannelInitializer(pResolver: PathResolver) + +/** Channel initializer that sets up the pipeline for the BlockServer. */ +private[netty] +class BlockServerChannelInitializer(dataProvider: BlockDataProvider) extends ChannelInitializer[SocketChannel] { - override def initChannel(channel: SocketChannel): Unit = { - channel.pipeline - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*)) - .addLast("stringDecoder", new StringDecoder) - .addLast("handler", new FileServerHandler(pResolver)) + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 + .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) + .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + .addLast("handler", new BlockServerHandler(dataProvider)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala new file mode 100644 index 0000000000000..40dd5e5d1a2ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.io.FileInputStream +import java.nio.ByteBuffer +import java.nio.channels.FileChannel + +import io.netty.buffer.Unpooled +import io.netty.channel._ + +import org.apache.spark.Logging +import org.apache.spark.storage.{FileSegment, BlockDataProvider} + + +/** + * A handler that processes requests from clients and writes block data back. + * + * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first + * so channelRead0 is called once per line (i.e. per block id). + */ +private[server] +class BlockServerHandler(dataProvider: BlockDataProvider) + extends SimpleChannelInboundHandler[String] with Logging { + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { + def client = ctx.channel.remoteAddress.toString + + // A helper function to send error message back to the client. + def respondWithError(error: String): Unit = { + ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) { + // TODO: Maybe log the success case as well. + logError(s"Error sending error back to $client", future.cause) + ctx.close() + } + } + } + ) + } + + def writeFileSegment(segment: FileSegment): Unit = { + // Send error message back if the block is too large. Even though we are capable of sending + // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. + // Once we fixed the receiving end to be able to process large blocks, this should be removed. + // Also make sure we update BlockHeaderEncoder to support length > 2G. + + // See [[BlockHeaderEncoder]] for the way length is encoded. + if (segment.length + blockId.length + 4 > Int.MaxValue) { + respondWithError(s"Block $blockId size ($segment.length) greater than 2G") + return + } + + var fileChannel: FileChannel = null + try { + fileChannel = new FileInputStream(segment.file).getChannel + } catch { + case e: Exception => + logError( + s"Error opening channel for $blockId in ${segment.file} for request from $client", e) + respondWithError(e.getMessage) + } + + // Found the block. Send it back. + if (fileChannel != null) { + // Write the header and block data. In the case of failures, the listener on the block data + // write should close the connection. + ctx.write(new BlockHeader(segment.length.toInt, blockId)) + + val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) + ctx.writeAndFlush(region).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${segment.length} B) back to $client") + } else { + logError(s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + }) + } + } + + def writeByteBuffer(buf: ByteBuffer): Unit = { + ctx.write(new BlockHeader(buf.remaining, blockId)) + ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") + } else { + logError(s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + }) + } + + logTrace(s"Received request from $client to fetch block $blockId") + + var blockData: Either[FileSegment, ByteBuffer] = null + + // First make sure we can find the block. If not, send error back to the user. + try { + blockData = dataProvider.getBlockData(blockId) + } catch { + case e: Exception => + logError(s"Error opening block $blockId for request from $client", e) + respondWithError(e.getMessage) + return + } + + blockData match { + case Left(segment) => writeFileSegment(segment) + case Right(buf) => writeByteBuffer(buf) + } + + } // end of channelRead0 +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala rename to core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala index f4261c13f70a8..5b6d086630834 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala @@ -15,17 +15,18 @@ * limitations under the License. */ -package org.apache.spark.network.netty +package org.apache.spark.storage -import io.netty.channel.ChannelInitializer -import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.string.StringEncoder +import java.nio.ByteBuffer -class FileClientChannelInitializer(handler: FileClientHandler) - extends ChannelInitializer[SocketChannel] { - - def initChannel(channel: SocketChannel) { - channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler) - } +/** + * An interface for providing data for blocks. + * + * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. + * + * Aside from unit tests, [[BlockManager]] is the main class that implements this. + */ +private[spark] trait BlockDataProvider { + def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 5f44f5f3197fd..91c0f47d51d02 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -18,19 +18,17 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue +import org.apache.spark.network.netty.client.{LazyInitIterator, ReferenceCountedBuffer} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue import scala.util.{Failure, Success} -import io.netty.buffer.ByteBuf - import org.apache.spark.{Logging, SparkException} import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.network.netty.ShuffleCopier import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -54,18 +52,28 @@ trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] wi private[storage] object BlockFetcherIterator { - // A request to fetch one or more blocks, complete with their sizes + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserializaton to happen in the calling thread); can also - // represent a fetch failure if size == -1. + /** + * Result of a fetch from a remote block. A failure is represented as size == -1. + * @param blockId block id + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param deserialize closure to return the result in the form of an Iterator. + */ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } + // TODO: Refactor this whole thing to make code more reusable. class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], @@ -95,10 +103,10 @@ object BlockFetcherIterator { // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that // the number of bytes in flight is limited to maxBytesInFlight - private val fetchRequests = new Queue[FetchRequest] + protected val fetchRequests = new Queue[FetchRequest] // Current bytes in flight from our requests - private var bytesInFlight = 0L + protected var bytesInFlight = 0L protected def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( @@ -262,77 +270,55 @@ object BlockFetcherIterator { readMetrics: ShuffleReadMetrics) extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { - import blockManager._ - - val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] - - private def startCopiers(numCopiers: Int): List[_ <: Thread] = { - (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { - sendRequest(fetchRequestsSync.take()) - } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - // case _ => throw new SparkException("Exception Throw in Shuffle Copier") - } - } - } - copier.start - copier - }).toList - } - - // keep this to interrupt the threads when necessary - private def stopCopiers() { - for (copier <- copiers) { - copier.interrupt() - } - } - override protected def sendRequest(req: FetchRequest) { - - def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) { - val fetchResult = new FetchResult(blockId, blockSize, - () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) - results.put(fetchResult) - } - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) - val cpier = new ShuffleCopier(blockManager.conf) - cpier.getBlocks(cmId, req.blocks, putResult) - logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) - } - - private var copiers: List[_ <: Thread] = null - - override def initialize() { - // Split Local Remote Blocks and set numBlocksToFetch - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - for (request <- Utils.randomize(remoteRequests)) { - fetchRequestsSync.put(request) - } - - copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6)) - logInfo("Started " + fetchRequestsSync.size + " remote fetches in " + - Utils.getUsedTimeMs(startTime)) + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID + + // This could throw a TimeoutException. In that case we will just retry the task. + val client = blockManager.nettyBlockClientFactory.createClient( + cmId.host, req.address.nettyPort) + val blocks = req.blocks.map(_._1.toString) + + client.fetchBlocks( + blocks, + (blockId: String, refBuf: ReferenceCountedBuffer) => { + // Increment the reference count so the buffer won't be recycled. + // TODO: This could result in memory leaks when the task is stopped due to exception + // before the iterator is exhausted. + refBuf.retain() + val buf = refBuf.byteBuffer() + val blockSize = buf.remaining() + val bid = BlockId(blockId) + + // TODO: remove code duplication between here and BlockManager.dataDeserialization. + results.put(new FetchResult(bid, sizeMap(bid), () => { + def createIterator: Iterator[Any] = { + val stream = blockManager.wrapForCompression(bid, refBuf.inputStream()) + serializer.newInstance().deserializeStream(stream).asIterator + } + new LazyInitIterator(createIterator) { + // Release the buffer when we are done traversing it. + override def close(): Unit = refBuf.release() + } + })) - override def next(): (BlockId, Option[Iterator[Any]]) = { - resultsGotten += 1 - val result = results.take() - // If all the results has been retrieved, copiers will exit automatically - (result.blockId, if (result.failed) None else Some(result.deserialize())) + readMetrics.synchronized { + readMetrics.remoteBytesRead += blockSize + readMetrics.remoteBlocksFetched += 1 + } + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + }, + (blockId: String, errorMsg: String) => { + logError(s"Could not get block(s) from $cmId with error: $errorMsg") + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } + ) } } // End of NettyBlockFetcherIterator diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e8bbd298c631a..e67676950b0ed 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -25,16 +25,19 @@ import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Cancellable, Props} +import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.netty.client.BlockFetchingClientFactory +import org.apache.spark.network.netty.server.BlockServer import org.apache.spark.serializer.Serializer import org.apache.spark.util._ + private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -58,7 +61,7 @@ private[spark] class BlockManager( val conf: SparkConf, securityManager: SecurityManager, mapOutputTracker: MapOutputTracker) - extends Logging { + extends BlockDataProvider with Logging { private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this) @@ -86,13 +89,25 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } + private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) + // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private val nettyPort: Int = { - val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) - val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0) - if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 + private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = { + if (useNetty) new BlockFetchingClientFactory(conf) else null } + private val nettyBlockServer: BlockServer = { + if (useNetty) { + val server = new BlockServer(conf, this) + logInfo(s"Created NettyBlockServer binding to port: ${server.port}") + server + } else { + null + } + } + + private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0 + val blockManagerId = BlockManagerId( executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) @@ -216,6 +231,20 @@ private[spark] class BlockManager( } } + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + val bid = BlockId(blockId) + if (bid.isShuffle) { + Left(diskBlockManager.getBlockLocation(bid)) + } else { + val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + if (blockBytesOpt.isDefined) { + Right(blockBytesOpt.get) + } else { + throw new BlockNotFoundException(blockId) + } + } + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. @@ -1061,6 +1090,14 @@ private[spark] class BlockManager( connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() + + if (nettyBlockClientFactory != null) { + nettyBlockClientFactory.stop() + } + if (nettyBlockServer != null) { + nettyBlockServer.stop() + } + actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala new file mode 100644 index 0000000000000..9ef453605f4f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + + +class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 4d66ccea211fa..f3da816389581 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -23,7 +23,7 @@ import java.util.{Date, Random, UUID} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.network.netty.{PathResolver, ShuffleSender} +import org.apache.spark.network.netty.PathResolver import org.apache.spark.util.Utils import org.apache.spark.shuffle.sort.SortShuffleManager @@ -52,7 +52,6 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - private var shuffleSender : ShuffleSender = null addShutdownHook() @@ -186,15 +185,5 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, } } } - - if (shuffleSender != null) { - shuffleSender.stop() - } - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - shuffleSender = new ShuffleSender(port, this) - logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}") - shuffleSender.port } } diff --git a/core/src/test/resources/netty-test-file.txt b/core/src/test/resources/netty-test-file.txt new file mode 100644 index 0000000000000..f59f293ee02ea --- /dev/null +++ b/core/src/test/resources/netty-test-file.txt @@ -0,0 +1,1379 @@ +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb +eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala new file mode 100644 index 0000000000000..ef3478a41e912 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.{RandomAccessFile, File} +import java.nio.ByteBuffer +import java.util.{Collections, HashSet} +import java.util.concurrent.{TimeUnit, Semaphore} + +import scala.collection.JavaConversions._ + +import io.netty.buffer.{ByteBufUtil, Unpooled} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.SparkConf +import org.apache.spark.network.netty.client.{ReferenceCountedBuffer, BlockFetchingClientFactory} +import org.apache.spark.network.netty.server.BlockServer +import org.apache.spark.storage.{FileSegment, BlockDataProvider} + + +/** + * Test suite that makes sure the server and the client implementations share the same protocol. + */ +class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { + + val bufSize = 100000 + var buf: ByteBuffer = _ + var testFile: File = _ + var server: BlockServer = _ + var clientFactory: BlockFetchingClientFactory = _ + + val bufferBlockId = "buffer_block" + val fileBlockId = "file_block" + + val fileContent = new Array[Byte](1024) + scala.util.Random.nextBytes(fileContent) + + override def beforeAll() = { + buf = ByteBuffer.allocate(bufSize) + for (i <- 1 to bufSize) { + buf.put(i.toByte) + } + buf.flip() + + testFile = File.createTempFile("netty-test-file", "txt") + val fp = new RandomAccessFile(testFile, "rw") + fp.write(fileContent) + fp.close() + + server = new BlockServer(new SparkConf, new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + if (blockId == bufferBlockId) { + Right(buf) + } else if (blockId == fileBlockId) { + Left(new FileSegment(testFile, 10, testFile.length - 25)) + } else { + throw new Exception("Unknown block id " + blockId) + } + } + }) + + clientFactory = new BlockFetchingClientFactory(new SparkConf) + } + + override def afterAll() = { + server.stop() + clientFactory.stop() + } + + /** A ByteBuf for buffer_block */ + lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) + + /** A ByteBuf for file_block */ + lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) + + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = + { + val client = clientFactory.createClient(server.hostName, server.port) + val sem = new Semaphore(0) + val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) + val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) + val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) + + client.fetchBlocks( + blockIds, + (blockId, buf) => { + receivedBlockIds.add(blockId) + buf.retain() + receivedBuffers.add(buf) + sem.release() + }, + (blockId, errorMsg) => { + errorBlockIds.add(blockId) + sem.release() + } + ) + if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server") + } + client.close() + (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) + } + + test("fetch a ByteBuffer block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) + assert(blockIds === Set(bufferBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch a FileSegment block via zero-copy send") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) + assert(blockIds === Set(fileBlockId)) + assert(buffers.map(_.underlying) === Set(fileBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch a non-existent block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) + assert(blockIds.isEmpty) + assert(buffers.isEmpty) + assert(failBlockIds === Set("random-block")) + } + + test("fetch both ByteBuffer block and FileSegment block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) + assert(blockIds === Set(bufferBlockId, fileBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch both ByteBuffer block and a non-existent block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) + assert(blockIds === Set(bufferBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala new file mode 100644 index 0000000000000..9afdad63b6988 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.nio.ByteBuffer + +import io.netty.buffer.Unpooled +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + + +class BlockFetchingClientHandlerSuite extends FunSuite { + + test("handling block data (successful fetch)") { + val blockId = "test_block" + val blockData = "blahblahblahblahblah" + val totalLength = 4 + blockId.length + blockData.length + + var parsedBlockId: String = "" + var parsedBlockData: String = "" + val handler = new BlockFetchingClientHandler + handler.blockFetchSuccessCallback = (bid, refCntBuf) => { + parsedBlockId = bid + val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) + refCntBuf.byteBuffer().get(bytes) + parsedBlockData = new String(bytes) + } + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself + buf.putInt(totalLength) + buf.putInt(blockId.length) + buf.put(blockId.getBytes) + buf.put(blockData.getBytes) + buf.flip() + + channel.writeInbound(Unpooled.wrappedBuffer(buf)) + assert(parsedBlockId === blockId) + assert(parsedBlockData === blockData) + + channel.close() + } + + test("handling error message (failed fetch)") { + val blockId = "test_block" + val errorMsg = "error erro5r error err4or error3 error6 error erro1r" + val totalLength = 4 + blockId.length + errorMsg.length + + var parsedBlockId: String = "" + var parsedErrorMsg: String = "" + val handler = new BlockFetchingClientHandler + handler.blockFetchFailureCallback = (bid, msg) => { + parsedBlockId = bid + parsedErrorMsg = msg + } + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself + buf.putInt(totalLength) + buf.putInt(-blockId.length) + buf.put(blockId.getBytes) + buf.put(errorMsg.getBytes) + buf.flip() + + channel.writeInbound(Unpooled.wrappedBuffer(buf)) + assert(parsedBlockId === blockId) + assert(parsedErrorMsg === errorMsg) + + channel.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala new file mode 100644 index 0000000000000..3ee281cb1350b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import io.netty.buffer.ByteBuf +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + + +class BlockHeaderEncoderSuite extends FunSuite { + + test("encode normal block data") { + val blockId = "test_block" + val channel = new EmbeddedChannel(new BlockHeaderEncoder) + channel.writeOutbound(new BlockHeader(17, blockId, None)) + val out = channel.readOutbound().asInstanceOf[ByteBuf] + assert(out.readInt() === 4 + blockId.length + 17) + assert(out.readInt() === blockId.length) + + val blockIdBytes = new Array[Byte](blockId.length) + out.readBytes(blockIdBytes) + assert(new String(blockIdBytes) === blockId) + assert(out.readableBytes() === 0) + + channel.close() + } + + test("encode error message") { + val blockId = "error_block" + val errorMsg = "error encountered" + val channel = new EmbeddedChannel(new BlockHeaderEncoder) + channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) + val out = channel.readOutbound().asInstanceOf[ByteBuf] + assert(out.readInt() === 4 + blockId.length + errorMsg.length) + assert(out.readInt() === -blockId.length) + + val blockIdBytes = new Array[Byte](blockId.length) + out.readBytes(blockIdBytes) + assert(new String(blockIdBytes) === blockId) + + val errorMsgBytes = new Array[Byte](errorMsg.length) + out.readBytes(errorMsgBytes) + assert(new String(errorMsgBytes) === errorMsg) + assert(out.readableBytes() === 0) + + channel.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala new file mode 100644 index 0000000000000..12f6d87616644 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.io.File +import java.nio.ByteBuffer + +import io.netty.buffer.{Unpooled, ByteBuf} +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + +import org.apache.spark.storage.{BlockDataProvider, FileSegment} + + +class BlockServerHandlerSuite extends FunSuite { + + test("ByteBuffer block") { + val expectedBlockId = "test_bytebuffer_block" + val buf = ByteBuffer.allocate(10000) + for (i <- 1 to 10000) { + buf.put(i.toByte) + } + buf.flip() + + val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) + })) + + channel.writeInbound(expectedBlockId) + assert(channel.outboundMessages().size === 2) + + val out1 = channel.readOutbound().asInstanceOf[BlockHeader] + val out2 = channel.readOutbound().asInstanceOf[ByteBuf] + + assert(out1.blockId === expectedBlockId) + assert(out1.blockSize === buf.remaining) + assert(out1.error === None) + + assert(out2.equals(Unpooled.wrappedBuffer(buf))) + + channel.close() + } + + test("FileSegment block via zero-copy") { + val expectedBlockId = "test_file_block" + val url = Thread.currentThread.getContextClassLoader.getResource("netty-test-file.txt") + val testFile = new File(url.toURI) + + val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + Left(new FileSegment(testFile, 15, testFile.length - 25)) + } + })) + + channel.writeInbound(expectedBlockId) + assert(channel.outboundMessages().size === 2) + + val out1 = channel.readOutbound().asInstanceOf[BlockHeader] + val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] + + assert(out1.blockId === expectedBlockId) + assert(out1.blockSize === testFile.length - 25) + assert(out1.error === None) + + assert(out2.count === testFile.length - 25) + assert(out2.position === 15) + } + + test("pipeline exception propagation") { + val blockServerHandler = new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? + }) + val exceptionHandler = new SimpleChannelInboundHandler[String]() { + override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { + throw new Exception("this is an error") + } + } + + val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) + assert(channel.isOpen) + channel.writeInbound("a message to trigger the error") + assert(!channel.isOpen) + } +} diff --git a/pom.xml b/pom.xml index c87f776bda659..da401c9753347 100644 --- a/pom.xml +++ b/pom.xml @@ -419,7 +419,7 @@ io.netty netty-all - 4.0.17.Final + 4.0.22.Final org.apache.derby From debb3e3df601bc64c97701565d2c992855f6cce9 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Fri, 15 Aug 2014 08:53:52 -0700 Subject: [PATCH 146/231] [SPARK-2924] remove default args to overloaded methods Not supported in Scala 2.11. Split them into separate methods instead. Author: Anand Avati Closes #1704 from avati/SPARK-1812-default-args and squashes the following commits: 3e3924a [Anand Avati] SPARK-1812: Add Mima excludes for the broken ABI 901dfc7 [Anand Avati] SPARK-1812: core - Fix overloaded methods with default arguments 07f00af [Anand Avati] SPARK-1812: streaming - Fix overloaded methods with default arguments (cherry picked from commit 7589c39d39a8d0744fb689e5752ee8e0108a81eb) Signed-off-by: Patrick Wendell --- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 2 +- project/MimaExcludes.scala | 3 +++ .../org/apache/spark/streaming/StreamingContext.scala | 8 +++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 29e9cf947856f..6b4689291097f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -93,7 +93,7 @@ private[spark] object JettyUtils extends Logging { def createServletHandler( path: String, servlet: HttpServlet, - basePath: String = ""): ServletContextHandler = { + basePath: String): ServletContextHandler = { val prefixedPath = attachPrefix(basePath, path) val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6e72035f2c15b..1e3c760b845de 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -117,6 +117,9 @@ object MimaExcludes { ) ++ Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") + ) ++ + Seq ( // Scala 2.11 compatibility fix + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") ) case v if v.startsWith("1.0") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index e0677b795cb94..101cec1c7a7c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -98,9 +98,15 @@ class StreamingContext private[streaming] ( * @param hadoopConf Optional, configuration object if necessary for reading from * HDFS compatible filesystems */ - def this(path: String, hadoopConf: Configuration = new Configuration) = + def this(path: String, hadoopConf: Configuration) = this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) + /** + * Recreate a StreamingContext from a checkpoint file. + * @param path Path to the directory that was specified as the checkpoint directory + */ + def this(path: String) = this(path, new Configuration) + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") From b066af4efb8dc544576f9f818d4974ac129c2ba7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 15 Aug 2014 09:01:35 -0700 Subject: [PATCH 147/231] Revert "[SPARK-2468] Netty based block server / client module" This reverts commit 3f23d2a38c3b6559902bc2ab6975ff6b0bec875e. --- .../spark/network/netty/FileClient.scala | 85 + .../netty/FileClientChannelInitializer.scala} | 21 +- .../network/netty/FileClientHandler.scala | 50 + .../spark/network/netty/FileHeader.scala | 71 + .../spark/network/netty/FileServer.scala | 91 ++ ...ala => FileServerChannelInitializer.scala} | 22 +- .../network/netty/FileServerHandler.scala | 68 + .../spark/network/netty/NettyConfig.scala | 59 - .../spark/network/netty/ShuffleCopier.scala | 118 ++ .../spark/network/netty/ShuffleSender.scala | 71 + .../netty/client/BlockFetchingClient.scala | 135 -- .../client/BlockFetchingClientFactory.scala | 99 -- .../client/BlockFetchingClientHandler.scala | 63 - .../netty/client/LazyInitIterator.scala | 44 - .../netty/client/ReferenceCountedBuffer.scala | 47 - .../network/netty/server/BlockHeader.scala | 32 - .../netty/server/BlockHeaderEncoder.scala | 47 - .../network/netty/server/BlockServer.scala | 162 -- .../netty/server/BlockServerHandler.scala | 140 -- .../spark/storage/BlockFetcherIterator.scala | 138 +- .../apache/spark/storage/BlockManager.scala | 49 +- .../storage/BlockNotFoundException.scala | 21 - .../spark/storage/DiskBlockManager.scala | 13 +- core/src/test/resources/netty-test-file.txt | 1379 ----------------- .../netty/ServerClientIntegrationSuite.scala | 158 -- .../BlockFetchingClientHandlerSuite.scala | 87 -- .../server/BlockHeaderEncoderSuite.scala | 64 - .../server/BlockServerHandlerSuite.scala | 101 -- pom.xml | 2 +- 29 files changed, 667 insertions(+), 2770 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClient.scala rename core/src/main/scala/org/apache/spark/{storage/BlockDataProvider.scala => network/netty/FileClientChannelInitializer.scala} (65%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServer.scala rename core/src/main/scala/org/apache/spark/network/netty/{server/BlockServerChannelInitializer.scala => FileServerChannelInitializer.scala} (58%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala delete mode 100644 core/src/test/resources/netty-test-file.txt delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala new file mode 100644 index 0000000000000..c6d35f73db545 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.Bootstrap +import io.netty.channel.{Channel, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioSocketChannel + +import org.apache.spark.Logging + +class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging { + + private var channel: Channel = _ + private var bootstrap: Bootstrap = _ + private var group: EventLoopGroup = _ + private val sendTimeout = 60 + + def init(): Unit = { + group = new OioEventLoopGroup + bootstrap = new Bootstrap + bootstrap.group(group) + .channel(classOf[OioSocketChannel]) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout)) + .handler(new FileClientChannelInitializer(handler)) + } + + def connect(host: String, port: Int) { + try { + channel = bootstrap.connect(host, port).sync().channel() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted while trying to connect", e) + close() + } + } + + def waitForClose(): Unit = { + try { + channel.closeFuture.sync() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted", e) + } + } + + def sendRequest(file: String): Unit = { + try { + val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS) + if (!bSent) { + throw new RuntimeException("Failed to send") + } + } catch { + case e: InterruptedException => + logError("Error", e) + } + } + + def close(): Unit = { + if (group != null) { + group.shutdownGracefully() + group = null + bootstrap = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala rename to core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala index 5b6d086630834..f4261c13f70a8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala @@ -15,18 +15,17 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.network.netty -import java.nio.ByteBuffer +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.string.StringEncoder -/** - * An interface for providing data for blocks. - * - * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. - * - * Aside from unit tests, [[BlockManager]] is the main class that implements this. - */ -private[spark] trait BlockDataProvider { - def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] +class FileClientChannelInitializer(handler: FileClientHandler) + extends ChannelInitializer[SocketChannel] { + + def initChannel(channel: SocketChannel) { + channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler) + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala new file mode 100644 index 0000000000000..017302ec7d33d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.buffer.ByteBuf +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.storage.BlockId + + +abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] { + + private var currentHeader: FileHeader = null + + @volatile + private var handlerCalled: Boolean = false + + def isComplete: Boolean = handlerCalled + + def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) + + def handleError(blockId: BlockId) + + override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { + if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE)) + } + if (in.readableBytes >= currentHeader.fileLen) { + handle(ctx, in, currentHeader) + handlerCalled = true + currentHeader = null + ctx.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala new file mode 100644 index 0000000000000..607e560ff277f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.buffer._ + +import org.apache.spark.Logging +import org.apache.spark.storage.{BlockId, TestBlockId} + +private[spark] class FileHeader ( + val fileLen: Int, + val blockId: BlockId) extends Logging { + + lazy val buffer: ByteBuf = { + val buf = Unpooled.buffer() + buf.capacity(FileHeader.HEADER_SIZE) + buf.writeInt(fileLen) + buf.writeInt(blockId.name.length) + blockId.name.foreach((x: Char) => buf.writeByte(x)) + // padding the rest of header + if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { + buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) + } else { + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") + } + buf + } + +} + +private[spark] object FileHeader { + + val HEADER_SIZE = 40 + + def getFileLenOffset = 0 + def getFileLenSize = Integer.SIZE/8 + + def create(buf: ByteBuf): FileHeader = { + val length = buf.readInt + val idLength = buf.readInt + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buf.readByte().asInstanceOf[Char] + } + val blockId = BlockId(idBuilder.toString()) + new FileHeader(length, blockId) + } + + def main(args:Array[String]) { + val header = new FileHeader(25, TestBlockId("my_block")) + val buf = header.buffer + val newHeader = FileHeader.create(buf) + System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala new file mode 100644 index 0000000000000..dff77950659af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.net.InetSocketAddress + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioServerSocketChannel + +import org.apache.spark.Logging + +/** + * Server that accept the path of a file an echo back its content. + */ +class FileServer(pResolver: PathResolver, private var port: Int) extends Logging { + + private val addr: InetSocketAddress = new InetSocketAddress(port) + private var bossGroup: EventLoopGroup = new OioEventLoopGroup + private var workerGroup: EventLoopGroup = new OioEventLoopGroup + + private var channelFuture: ChannelFuture = { + val bootstrap = new ServerBootstrap + bootstrap.group(bossGroup, workerGroup) + .channel(classOf[OioServerSocketChannel]) + .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100)) + .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500)) + .childHandler(new FileServerChannelInitializer(pResolver)) + bootstrap.bind(addr) + } + + try { + val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress] + port = boundAddress.getPort + } catch { + case ie: InterruptedException => + port = 0 + } + + /** Start the file server asynchronously in a new thread. */ + def start(): Unit = { + val blockingThread: Thread = new Thread { + override def run(): Unit = { + try { + channelFuture.channel.closeFuture.sync + logInfo("FileServer exiting") + } catch { + case e: InterruptedException => + logError("File server start got interrupted", e) + } + // NOTE: bootstrap is shutdown in stop() + } + } + blockingThread.setDaemon(true) + blockingThread.start() + } + + def getPort: Int = port + + def stop(): Unit = { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly() + channelFuture = null + } + if (bossGroup != null) { + bossGroup.shutdownGracefully() + bossGroup = null + } + if (workerGroup != null) { + workerGroup.shutdownGracefully() + workerGroup = null + } + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala similarity index 58% rename from core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala rename to core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala index cc70bd0c5c477..aaa2f913d0269 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala @@ -15,26 +15,20 @@ * limitations under the License. */ -package org.apache.spark.network.netty.server +package org.apache.spark.network.netty import io.netty.channel.ChannelInitializer import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder +import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters} import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil -import org.apache.spark.storage.BlockDataProvider - -/** Channel initializer that sets up the pipeline for the BlockServer. */ -private[netty] -class BlockServerChannelInitializer(dataProvider: BlockDataProvider) +class FileServerChannelInitializer(pResolver: PathResolver) extends ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) + override def initChannel(channel: SocketChannel): Unit = { + channel.pipeline + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*)) + .addLast("stringDecoder", new StringDecoder) + .addLast("handler", new FileServerHandler(pResolver)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala new file mode 100644 index 0000000000000..96f60b2883ad9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.FileInputStream + +import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.storage.{BlockId, FileSegment} + + +class FileServerHandler(pResolver: PathResolver) + extends SimpleChannelInboundHandler[String] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = { + val blockId: BlockId = BlockId(blockIdString) + val fileSegment: FileSegment = pResolver.getBlockLocation(blockId) + if (fileSegment == null) { + return + } + val file = fileSegment.file + if (file.exists) { + if (!file.isFile) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + val length: Long = fileSegment.length + if (length > Integer.MAX_VALUE || length <= 0) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + ctx.write(new FileHeader(length.toInt, blockId).buffer) + try { + val channel = new FileInputStream(file).getChannel + ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length)) + } catch { + case e: Exception => + logError("Exception: ", e) + } + } else { + ctx.write(new FileHeader(0, blockId).buffer) + } + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError("Exception: ", cause) + ctx.close() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala deleted file mode 100644 index b5870152c5a64..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import org.apache.spark.SparkConf - -/** - * A central location that tracks all the settings we exposed to users. - */ -private[spark] -class NettyConfig(conf: SparkConf) { - - /** Port the server listens on. Default to a random port. */ - private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) - - /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ - private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase - - /** Connect timeout in secs. Default 60 secs. */ - private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 - - /** - * Percentage of the desired amount of time spent for I/O in the child event loops. - * Only applicable in nio and epoll. - */ - private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) - - /** Requested maximum length of the queue of incoming connections. */ - private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) - - /** - * Receive buffer size (SO_RCVBUF). - * Note: the optimal size for receive buffer and send buffer should be - * latency * network_bandwidth. - * Assuming latency = 1ms, network_bandwidth = 10Gbps - * buffer size should be ~ 1.25MB - */ - private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) - - /** Send buffer size (SO_SNDBUF). */ - private[netty] val sendBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala new file mode 100644 index 0000000000000..e7b2855e1ec91 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.concurrent.Executors + +import scala.collection.JavaConverters._ + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.util.CharsetUtil + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.ConnectionManagerId +import org.apache.spark.storage.BlockId + +private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { + + def getBlock(host: String, port: Int, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { + + val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) + val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000) + val fc = new FileClient(handler, connectTimeout) + + try { + fc.init() + fc.connect(host, port) + fc.sendRequest(blockId.name) + fc.waitForClose() + fc.close() + } catch { + // Handle any socket-related exceptions in FileClient + case e: Exception => { + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) + handler.handleError(blockId) + } + } + } + + def getBlock(cmId: ConnectionManagerId, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { + getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) + } + + def getBlocks(cmId: ConnectionManagerId, + blocks: Seq[(BlockId, Long)], + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { + + for ((blockId, size) <- blocks) { + getBlock(cmId, blockId, resultCollectCallback) + } + } +} + + +private[spark] object ShuffleCopier extends Logging { + + private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit) + extends FileClientHandler with Logging { + + override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } + + override def handleError(blockId: BlockId) { + if (!isComplete) { + resultCollectCallBack(blockId, -1, null) + } + } + } + + def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) { + if (size != -1) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleCopier ") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val blockId = BlockId(args(2)) + val threads = if (args.length > 3) args(3).toInt else 10 + + val copiers = Executors.newFixedThreadPool(80) + val tasks = (for (i <- Range(0, threads)) yield { + Executors.callable(new Runnable() { + def run() { + val copier = new ShuffleCopier(new SparkConf) + copier.getBlock(host, port, blockId, echoResultCollectCallBack) + } + }) + }).asJava + copiers.invokeAll(tasks) + copiers.shutdown() + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala new file mode 100644 index 0000000000000..95958e30f7eeb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.File + +import org.apache.spark.Logging +import org.apache.spark.util.Utils +import org.apache.spark.storage.{BlockId, FileSegment} + +private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { + + val server = new FileServer(pResolver, portIn) + server.start() + + def stop() { + server.stop() + } + + def port: Int = server.getPort +} + + +/** + * An application for testing the shuffle sender as a standalone program. + */ +private[spark] object ShuffleSender { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ShuffleSender ") + System.exit(1) + } + + val port = args(0).toInt + val subDirsPerLocalDir = args(1).toInt + val localDirs = args.drop(2).map(new File(_)) + + val pResovler = new PathResolver { + override def getBlockLocation(blockId: BlockId): FileSegment = { + if (!blockId.isShuffle) { + throw new Exception("Block " + blockId + " is not a shuffle block") + } + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = Utils.nonNegativeHash(blockId) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) + val file = new File(subDir, blockId.name) + new FileSegment(file, 0, file.length()) + } + } + val sender = new ShuffleSender(port, pResovler) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala deleted file mode 100644 index 9fed11b75c342..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.util.concurrent.TimeoutException - -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder -import io.netty.handler.codec.string.StringEncoder -import io.netty.util.CharsetUtil - -import org.apache.spark.Logging - -/** - * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. - * Use [[BlockFetchingClientFactory]] to instantiate this client. - * - * The constructor blocks until a connection is successfully established. - * - * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. - * - * Concurrency: [[BlockFetchingClient]] is not thread safe and should not be shared. - */ -@throws[TimeoutException] -private[spark] -class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) - extends Logging { - - val handler = new BlockFetchingClientHandler - - /** Netty Bootstrap for creating the TCP connection. */ - private val bootstrap: Bootstrap = { - val b = new Bootstrap - b.group(factory.workerGroup) - .channel(factory.socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) - - b.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) - // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 - .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) - .addLast("handler", handler) - } - }) - b - } - - /** Netty ChannelFuture for the connection. */ - private val cf: ChannelFuture = bootstrap.connect(hostname, port) - if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") - } - - /** - * Ask the remote server for a sequence of blocks, and execute the callback. - * - * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory. - * - * @param blockIds sequence of block ids to fetch. - * @param blockFetchSuccessCallback callback function when a block is successfully fetched. - * First argument is the block id, and second argument is the - * raw data in a ByteBuffer. - * @param blockFetchFailureCallback callback function when we failed to fetch any of the blocks. - * First argument is the block id, and second argument is the - * error message. - */ - def fetchBlocks( - blockIds: Seq[String], - blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit, - blockFetchFailureCallback: (String, String) => Unit): Unit = { - // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. - // It's also best to limit the number of "flush" calls since it requires system calls. - // Let's concatenate the string and then call writeAndFlush once. - // This is also why this implementation might be more efficient than multiple, separate - // fetch block calls. - var startTime: Long = 0 - logTrace { - startTime = System.nanoTime - s"Sending request $blockIds to $hostname:$port" - } - - // TODO: This is not the most elegant way to handle this ... - handler.blockFetchSuccessCallback = blockFetchSuccessCallback - handler.blockFetchFailureCallback = blockFetchFailureCallback - - val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") - writeFuture.addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 - s"Sending request $blockIds to $hostname:$port took $timeTaken ms" - } - } else { - // Fail all blocks. - logError(s"Failed to send request $blockIds to $hostname:$port", future.cause) - blockIds.foreach(blockFetchFailureCallback(_, future.cause.getMessage)) - } - } - }) - } - - def waitForClose(): Unit = { - cf.channel().closeFuture().sync() - } - - def close(): Unit = cf.channel().close() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala deleted file mode 100644 index 2b28402c52b49..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{EventLoopGroup, Channel} - -import org.apache.spark.SparkConf -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.util.Utils - -/** - * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses - * the worker thread pool for Netty. - * - * Concurrency: createClient is safe to be called from multiple threads concurrently. - */ -private[spark] -class BlockFetchingClientFactory(val conf: NettyConfig) { - - def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) - - /** A thread factory so the threads are named (for debugging). */ - val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") - - /** The following two are instantiated by the [[init]] method, depending ioMode. */ - var socketChannelClass: Class[_ <: Channel] = _ - var workerGroup: EventLoopGroup = _ - - init() - - /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ - private def init(): Unit = { - def initOio(): Unit = { - socketChannelClass = classOf[OioSocketChannel] - workerGroup = new OioEventLoopGroup(0, threadFactory) - } - def initNio(): Unit = { - socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(0, threadFactory) - } - def initEpoll(): Unit = { - socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(0, threadFactory) - } - - conf.ioMode match { - case "nio" => initNio() - case "oio" => initOio() - case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } - } - } - - /** - * Create a new BlockFetchingClient connecting to the given remote host / port. - * - * This blocks until a connection is successfully established. - * - * Concurrency: This method is safe to call from multiple threads. - */ - def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { - new BlockFetchingClient(this, remoteHost, remotePort) - } - - def stop(): Unit = { - if (workerGroup != null) { - workerGroup.shutdownGracefully() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala deleted file mode 100644 index a1dbf6102c080..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging - - -/** - * Handler that processes server responses. It uses the protocol documented in - * [[org.apache.spark.network.netty.server.BlockServer]]. - */ -private[client] -class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { - - var blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit = _ - var blockFetchFailureCallback: (String, String) => Unit = _ - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - val totalLen = in.readInt() - val blockIdLen = in.readInt() - val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) - in.readBytes(blockIdBytes) - val blockId = new String(blockIdBytes) - val blockSize = totalLen - math.abs(blockIdLen) - 4 - - def server = ctx.channel.remoteAddress.toString - - // blockIdLen is negative when it is an error message. - if (blockIdLen < 0) { - val errorMessageBytes = new Array[Byte](blockSize) - in.readBytes(errorMessageBytes) - val errorMsg = new String(errorMessageBytes) - logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") - blockFetchFailureCallback(blockId, errorMsg) - } else { - logTrace(s"Received block $blockId ($blockSize B) from $server") - blockFetchSuccessCallback(blockId, new ReferenceCountedBuffer(in)) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala deleted file mode 100644 index 9740ee64d1f2d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -/** - * A simple iterator that lazily initializes the underlying iterator. - * - * The use case is that sometimes we might have many iterators open at the same time, and each of - * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). - * This could lead to too many buffers open. If this iterator is used, we lazily initialize those - * buffers. - */ -private[spark] -class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { - - lazy val proxy = createIterator - - override def hasNext: Boolean = { - val gotNext = proxy.hasNext - if (!gotNext) { - close() - } - gotNext - } - - override def next(): Any = proxy.next() - - def close(): Unit = Unit -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala deleted file mode 100644 index ea1abf5eccc26..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.{ByteBuf, ByteBufInputStream} - - -/** - * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. - * This is a Scala value class. - * - * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of - * reference by the retain method and release method. - */ -private[spark] -class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { - - /** Return the nio ByteBuffer view of the underlying buffer. */ - def byteBuffer(): ByteBuffer = underlying.nioBuffer - - /** Creates a new input stream that starts from the current position of the buffer. */ - def inputStream(): InputStream = new ByteBufInputStream(underlying) - - /** Increment the reference counter by one. */ - def retain(): Unit = underlying.retain() - - /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ - def release(): Unit = underlying.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala deleted file mode 100644 index 162e9cc6828d4..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -/** - * Header describing a block. This is used only in the server pipeline. - * - * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. - * - * @param blockSize length of the block content, excluding the length itself. - * If positive, this is the header for a block (not part of the header). - * If negative, this is the header and content for an error message. - * @param blockId block id - * @param error some error message from reading the block - */ -private[server] -class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala deleted file mode 100644 index 8e4dda4ef8595..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.MessageToByteEncoder - -/** - * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. - */ -private[server] -class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { - override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { - // message = message length (4 bytes) + block id length (4 bytes) + block id + block data - // message length = block id length (4 bytes) + size of block id + size of block data - val blockIdBytes = msg.blockId.getBytes - msg.error match { - case Some(errorMsg) => - val errorBytes = errorMsg.getBytes - out.writeInt(4 + blockIdBytes.length + errorBytes.size) - out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors - out.writeBytes(blockIdBytes) // next is blockId itself - out.writeBytes(errorBytes) // error message - case None => - out.writeInt(4 + blockIdBytes.length + msg.blockSize) - out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length - out.writeBytes(blockIdBytes) // next is blockId itself - // msg of size blockSize will be written by ServerHandler - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala deleted file mode 100644 index 7b2f9a8d4dfd0..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.oio.OioServerSocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.storage.BlockDataProvider -import org.apache.spark.util.Utils - - -/** - * Server for serving Spark data blocks. - * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. - * - * Protocol for requesting blocks (client to server): - * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" - * - * Protocol for sending blocks (server to client): - * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. - * - * frame-length should not include the length of itself. - * If block-id-length is negative, then this is an error message rather than block-data. The real - * length is the absolute value of the frame-length. - * - */ -private[spark] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { - - def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { - this(new NettyConfig(sparkConf), dataProvider) - } - - def port: Int = _port - - def hostName: String = _hostName - - private var _port: Int = conf.serverPort - private var _hostName: String = "" - private var bootstrap: ServerBootstrap = _ - private var channelFuture: ChannelFuture = _ - - init() - - /** Initialize the server. */ - private def init(): Unit = { - bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") - - // Use only one thread to accept connections, and 2 * num_cores for worker. - def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) - bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) - } - def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) - bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) - } - def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) - val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) - bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) - } - - conf.ioMode match { - case "nio" => initNio() - case "oio" => initOio() - case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } - } - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - - // Various (advanced) user-configured settings. - conf.backLog.foreach { backLog => - bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) - } - conf.receiveBuf.foreach { receiveBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) - } - conf.sendBuf.foreach { sendBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) - } - - bootstrap.childHandler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } - }) - - channelFuture = bootstrap.bind(new InetSocketAddress(_port)) - channelFuture.sync() - - val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] - _port = addr.getPort - _hostName = addr.getHostName - } - - /** Shutdown the server. */ - def stop(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() - } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() - } - bootstrap = null - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala deleted file mode 100644 index 40dd5e5d1a2ac..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.FileInputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel - -import io.netty.buffer.Unpooled -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first - * so channelRead0 is called once per line (i.e. per block id). - */ -private[server] -class BlockServerHandler(dataProvider: BlockDataProvider) - extends SimpleChannelInboundHandler[String] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { - def client = ctx.channel.remoteAddress.toString - - // A helper function to send error message back to the client. - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - def writeFileSegment(segment: FileSegment): Unit = { - // Send error message back if the block is too large. Even though we are capable of sending - // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. - // Once we fixed the receiving end to be able to process large blocks, this should be removed. - // Also make sure we update BlockHeaderEncoder to support length > 2G. - - // See [[BlockHeaderEncoder]] for the way length is encoded. - if (segment.length + blockId.length + 4 > Int.MaxValue) { - respondWithError(s"Block $blockId size ($segment.length) greater than 2G") - return - } - - var fileChannel: FileChannel = null - try { - fileChannel = new FileInputStream(segment.file).getChannel - } catch { - case e: Exception => - logError( - s"Error opening channel for $blockId in ${segment.file} for request from $client", e) - respondWithError(e.getMessage) - } - - // Found the block. Send it back. - if (fileChannel != null) { - // Write the header and block data. In the case of failures, the listener on the block data - // write should close the connection. - ctx.write(new BlockHeader(segment.length.toInt, blockId)) - - val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) - ctx.writeAndFlush(region).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${segment.length} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - } - - def writeByteBuffer(buf: ByteBuffer): Unit = { - ctx.write(new BlockHeader(buf.remaining, blockId)) - ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - var blockData: Either[FileSegment, ByteBuffer] = null - - // First make sure we can find the block. If not, send error back to the user. - try { - blockData = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - blockData match { - case Left(segment) => writeFileSegment(segment) - case Right(buf) => writeByteBuffer(buf) - } - - } // end of channelRead0 -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 91c0f47d51d02..5f44f5f3197fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -18,17 +18,19 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue -import org.apache.spark.network.netty.client.{LazyInitIterator, ReferenceCountedBuffer} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue import scala.util.{Failure, Success} +import io.netty.buffer.ByteBuf + import org.apache.spark.{Logging, SparkException} import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId +import org.apache.spark.network.netty.ShuffleCopier import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -52,28 +54,18 @@ trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] wi private[storage] object BlockFetcherIterator { - /** - * A request to fetch blocks from a remote BlockManager. - * @param address remote BlockManager to fetch from. - * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. - */ + // A request to fetch one or more blocks, complete with their sizes class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } - /** - * Result of a fetch from a remote block. A failure is represented as size == -1. - * @param blockId block id - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. - */ + // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + // the block (since we want all deserializaton to happen in the calling thread); can also + // represent a fetch failure if size == -1. class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } - // TODO: Refactor this whole thing to make code more reusable. class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], @@ -103,10 +95,10 @@ object BlockFetcherIterator { // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that // the number of bytes in flight is limited to maxBytesInFlight - protected val fetchRequests = new Queue[FetchRequest] + private val fetchRequests = new Queue[FetchRequest] // Current bytes in flight from our requests - protected var bytesInFlight = 0L + private var bytesInFlight = 0L protected def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( @@ -270,55 +262,77 @@ object BlockFetcherIterator { readMetrics: ShuffleReadMetrics) extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { - override protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) + import blockManager._ - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - - // This could throw a TimeoutException. In that case we will just retry the task. - val client = blockManager.nettyBlockClientFactory.createClient( - cmId.host, req.address.nettyPort) - val blocks = req.blocks.map(_._1.toString) - - client.fetchBlocks( - blocks, - (blockId: String, refBuf: ReferenceCountedBuffer) => { - // Increment the reference count so the buffer won't be recycled. - // TODO: This could result in memory leaks when the task is stopped due to exception - // before the iterator is exhausted. - refBuf.retain() - val buf = refBuf.byteBuffer() - val blockSize = buf.remaining() - val bid = BlockId(blockId) - - // TODO: remove code duplication between here and BlockManager.dataDeserialization. - results.put(new FetchResult(bid, sizeMap(bid), () => { - def createIterator: Iterator[Any] = { - val stream = blockManager.wrapForCompression(bid, refBuf.inputStream()) - serializer.newInstance().deserializeStream(stream).asIterator - } - new LazyInitIterator(createIterator) { - // Release the buffer when we are done traversing it. - override def close(): Unit = refBuf.release() + val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + + private def startCopiers(numCopiers: Int): List[_ <: Thread] = { + (for ( i <- Range(0,numCopiers) ) yield { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { + sendRequest(fetchRequestsSync.take()) + } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + // case _ => throw new SparkException("Exception Throw in Shuffle Copier") } - })) - - readMetrics.synchronized { - readMetrics.remoteBytesRead += blockSize - readMetrics.remoteBlocksFetched += 1 - } - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - }, - (blockId: String, errorMsg: String) => { - logError(s"Could not get block(s) from $cmId with error: $errorMsg") - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) } } - ) + copier.start + copier + }).toList + } + + // keep this to interrupt the threads when necessary + private def stopCopiers() { + for (copier <- copiers) { + copier.interrupt() + } + } + + override protected def sendRequest(req: FetchRequest) { + + def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) { + val fetchResult = new FetchResult(blockId, blockSize, + () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) + results.put(fetchResult) + } + + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.host)) + val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) + val cpier = new ShuffleCopier(blockManager.conf) + cpier.getBlocks(cmId, req.blocks, putResult) + logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) + } + + private var copiers: List[_ <: Thread] = null + + override def initialize() { + // Split Local Remote Blocks and set numBlocksToFetch + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + for (request <- Utils.randomize(remoteRequests)) { + fetchRequestsSync.put(request) + } + + copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6)) + logInfo("Started " + fetchRequestsSync.size + " remote fetches in " + + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def next(): (BlockId, Option[Iterator[Any]]) = { + resultsGotten += 1 + val result = results.take() + // If all the results has been retrieved, copiers will exit automatically + (result.blockId, if (result.failed) None else Some(result.deserialize())) } } // End of NettyBlockFetcherIterator diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e67676950b0ed..e8bbd298c631a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -25,19 +25,16 @@ import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} +import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ -import org.apache.spark.network.netty.client.BlockFetchingClientFactory -import org.apache.spark.network.netty.server.BlockServer import org.apache.spark.serializer.Serializer import org.apache.spark.util._ - private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -61,7 +58,7 @@ private[spark] class BlockManager( val conf: SparkConf, securityManager: SecurityManager, mapOutputTracker: MapOutputTracker) - extends BlockDataProvider with Logging { + extends Logging { private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this) @@ -89,25 +86,13 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } - private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) - // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = { - if (useNetty) new BlockFetchingClientFactory(conf) else null + private val nettyPort: Int = { + val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) + val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0) + if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - private val nettyBlockServer: BlockServer = { - if (useNetty) { - val server = new BlockServer(conf, this) - logInfo(s"Created NettyBlockServer binding to port: ${server.port}") - server - } else { - null - } - } - - private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0 - val blockManagerId = BlockManagerId( executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) @@ -231,20 +216,6 @@ private[spark] class BlockManager( } } - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - val bid = BlockId(blockId) - if (bid.isShuffle) { - Left(diskBlockManager.getBlockLocation(bid)) - } else { - val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] - if (blockBytesOpt.isDefined) { - Right(blockBytesOpt.get) - } else { - throw new BlockNotFoundException(blockId) - } - } - } - /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. @@ -1090,14 +1061,6 @@ private[spark] class BlockManager( connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() - - if (nettyBlockClientFactory != null) { - nettyBlockClientFactory.stop() - } - if (nettyBlockServer != null) { - nettyBlockServer.stop() - } - actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala deleted file mode 100644 index 9ef453605f4f1..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - - -class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f3da816389581..4d66ccea211fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -23,7 +23,7 @@ import java.util.{Date, Random, UUID} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.network.netty.PathResolver +import org.apache.spark.network.netty.{PathResolver, ShuffleSender} import org.apache.spark.util.Utils import org.apache.spark.shuffle.sort.SortShuffleManager @@ -52,6 +52,7 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + private var shuffleSender : ShuffleSender = null addShutdownHook() @@ -185,5 +186,15 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, } } } + + if (shuffleSender != null) { + shuffleSender.stop() + } + } + + private[storage] def startShuffleBlockSender(port: Int): Int = { + shuffleSender = new ShuffleSender(port, this) + logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}") + shuffleSender.port } } diff --git a/core/src/test/resources/netty-test-file.txt b/core/src/test/resources/netty-test-file.txt deleted file mode 100644 index f59f293ee02ea..0000000000000 --- a/core/src/test/resources/netty-test-file.txt +++ /dev/null @@ -1,1379 +0,0 @@ -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb -eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee -aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala deleted file mode 100644 index ef3478a41e912..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer -import java.util.{Collections, HashSet} -import java.util.concurrent.{TimeUnit, Semaphore} - -import scala.collection.JavaConversions._ - -import io.netty.buffer.{ByteBufUtil, Unpooled} - -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.SparkConf -import org.apache.spark.network.netty.client.{ReferenceCountedBuffer, BlockFetchingClientFactory} -import org.apache.spark.network.netty.server.BlockServer -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * Test suite that makes sure the server and the client implementations share the same protocol. - */ -class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { - - val bufSize = 100000 - var buf: ByteBuffer = _ - var testFile: File = _ - var server: BlockServer = _ - var clientFactory: BlockFetchingClientFactory = _ - - val bufferBlockId = "buffer_block" - val fileBlockId = "file_block" - - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - - override def beforeAll() = { - buf = ByteBuffer.allocate(bufSize) - for (i <- 1 to bufSize) { - buf.put(i.toByte) - } - buf.flip() - - testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - server = new BlockServer(new SparkConf, new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - if (blockId == bufferBlockId) { - Right(buf) - } else if (blockId == fileBlockId) { - Left(new FileSegment(testFile, 10, testFile.length - 25)) - } else { - throw new Exception("Unknown block id " + blockId) - } - } - }) - - clientFactory = new BlockFetchingClientFactory(new SparkConf) - } - - override def afterAll() = { - server.stop() - clientFactory.stop() - } - - /** A ByteBuf for buffer_block */ - lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) - - /** A ByteBuf for file_block */ - lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = - { - val client = clientFactory.createClient(server.hostName, server.port) - val sem = new Semaphore(0) - val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) - val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) - - client.fetchBlocks( - blockIds, - (blockId, buf) => { - receivedBlockIds.add(blockId) - buf.retain() - receivedBuffers.add(buf) - sem.release() - }, - (blockId, errorMsg) => { - errorBlockIds.add(blockId) - sem.release() - } - ) - if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server") - } - client.close() - (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) - } - - test("fetch a ByteBuffer block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a FileSegment block via zero-copy send") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) - assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.underlying) === Set(fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) - assert(blockIds.isEmpty) - assert(buffers.isEmpty) - assert(failBlockIds === Set("random-block")) - } - - test("fetch both ByteBuffer block and FileSegment block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) - assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala deleted file mode 100644 index 9afdad63b6988..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - - -class BlockFetchingClientHandlerSuite extends FunSuite { - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val totalLength = 4 + blockId.length + blockData.length - - var parsedBlockId: String = "" - var parsedBlockData: String = "" - val handler = new BlockFetchingClientHandler - handler.blockFetchSuccessCallback = (bid, refCntBuf) => { - parsedBlockId = bid - val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) - refCntBuf.byteBuffer().get(bytes) - parsedBlockData = new String(bytes) - } - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(blockId.length) - buf.put(blockId.getBytes) - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedBlockData === blockData) - - channel.close() - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val errorMsg = "error erro5r error err4or error3 error6 error erro1r" - val totalLength = 4 + blockId.length + errorMsg.length - - var parsedBlockId: String = "" - var parsedErrorMsg: String = "" - val handler = new BlockFetchingClientHandler - handler.blockFetchFailureCallback = (bid, msg) => { - parsedBlockId = bid - parsedErrorMsg = msg - } - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(-blockId.length) - buf.put(blockId.getBytes) - buf.put(errorMsg.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedErrorMsg === errorMsg) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala deleted file mode 100644 index 3ee281cb1350b..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - - -class BlockHeaderEncoderSuite extends FunSuite { - - test("encode normal block data") { - val blockId = "test_block" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, None)) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + 17) - assert(out.readInt() === blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - assert(out.readableBytes() === 0) - - channel.close() - } - - test("encode error message") { - val blockId = "error_block" - val errorMsg = "error encountered" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + errorMsg.length) - assert(out.readInt() === -blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - - val errorMsgBytes = new Array[Byte](errorMsg.length) - out.readBytes(errorMsgBytes) - assert(new String(errorMsgBytes) === errorMsg) - assert(out.readableBytes() === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala deleted file mode 100644 index 12f6d87616644..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.File -import java.nio.ByteBuffer - -import io.netty.buffer.{Unpooled, ByteBuf} -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.storage.{BlockDataProvider, FileSegment} - - -class BlockServerHandlerSuite extends FunSuite { - - test("ByteBuffer block") { - val expectedBlockId = "test_bytebuffer_block" - val buf = ByteBuffer.allocate(10000) - for (i <- 1 to 10000) { - buf.put(i.toByte) - } - buf.flip() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[ByteBuf] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === buf.remaining) - assert(out1.error === None) - - assert(out2.equals(Unpooled.wrappedBuffer(buf))) - - channel.close() - } - - test("FileSegment block via zero-copy") { - val expectedBlockId = "test_file_block" - val url = Thread.currentThread.getContextClassLoader.getResource("netty-test-file.txt") - val testFile = new File(url.toURI) - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - Left(new FileSegment(testFile, 15, testFile.length - 25)) - } - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === testFile.length - 25) - assert(out1.error === None) - - assert(out2.count === testFile.length - 25) - assert(out2.position === 15) - } - - test("pipeline exception propagation") { - val blockServerHandler = new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? - }) - val exceptionHandler = new SimpleChannelInboundHandler[String]() { - override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { - throw new Exception("this is an error") - } - } - - val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) - assert(channel.isOpen) - channel.writeInbound("a message to trigger the error") - assert(!channel.isOpen) - } -} diff --git a/pom.xml b/pom.xml index da401c9753347..c87f776bda659 100644 --- a/pom.xml +++ b/pom.xml @@ -419,7 +419,7 @@ io.netty netty-all - 4.0.22.Final + 4.0.17.Final org.apache.derby From 63376a0eeffa611ccfdf1e023bc0cf3393d70139 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 15 Aug 2014 11:35:08 -0700 Subject: [PATCH 148/231] SPARK-3028. sparkEventToJson should support SparkListenerExecutorMetrics... ...Update Author: Sandy Ryza Closes #1961 from sryza/sandy-spark-3028 and squashes the following commits: dccdff5 [Sandy Ryza] Fix compile error f883ded [Sandy Ryza] SPARK-3028. sparkEventToJson should support SparkListenerExecutorMetricsUpdate (cherry picked from commit 0afe5cb65a195d2f14e8dfcefdbec5dac023651f) Signed-off-by: Patrick Wendell --- .../org/apache/spark/scheduler/EventLoggingListener.scala | 2 ++ core/src/main/scala/org/apache/spark/util/JsonProtocol.scala | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 406147f167bf3..7378ce923f0ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -127,6 +127,8 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) override def onApplicationEnd(event: SparkListenerApplicationEnd) = logEvent(event, flushLogger = true) + // No-op because logging every update would be overkill + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } /** * Stop logging events. diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 6f8eb1ee12634..1e18ec688c40d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -72,8 +72,9 @@ private[spark] object JsonProtocol { case applicationEnd: SparkListenerApplicationEnd => applicationEndToJson(applicationEnd) - // Not used, but keeps compiler happy + // These aren't used, but keeps compiler happy case SparkListenerShutdown => JNothing + case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } From 407ea9fd6f68ff3237726841b80dec61cbc7f51c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 15 Aug 2014 14:50:10 -0700 Subject: [PATCH 149/231] [SPARK-3022] [SPARK-3041] [mllib] Call findBins once per level + unordered feature bug fix DecisionTree improvements: (1) TreePoint representation to avoid binning multiple times (2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features (3) Timing for DecisionTree internals Details: (1) TreePoint representation to avoid binning multiple times [https://issues.apache.org/jira/browse/SPARK-3022] Added private[tree] TreePoint class for representing binned feature values. The input RDD of LabeledPoint is converted to the TreePoint representation initially and then cached. This avoids the previous problem of re-computing bins multiple times. (2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features [https://issues.apache.org/jira/browse/SPARK-3041] isSampleValid used to treat unordered categorical features incorrectly: It treated the bins as if indexed by featured values, rather than by subsets of values/categories. * exhibited for unordered features (multi-class classification with categorical features of low arity) * Fix: Index bins correctly for unordered categorical features. (3) Timing for DecisionTree internals Added tree/impl/TimeTracker.scala class which is private[tree] for now, for timing key parts of DT code. Prints timing info via logDebug. CC: mengxr manishamde chouqin Very similar update, with one bug fix. Many apologies for the conflicting update, but I hope that a few more optimizations I have on the way (which depend on this update) will prove valuable to you: SPARK-3042 and SPARK-3043 Author: Joseph K. Bradley Closes #1950 from jkbradley/dt-opt1 and squashes the following commits: 5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint 6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory. 2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs. d036089 [Joseph K. Bradley] Print timing info to logDebug. e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private 8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serialiable a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree 3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing 511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing a95bc22 [Joseph K. Bradley] timing for DecisionTree internals (cherry picked from commit c7032290a3f0f5545aa4f0a9a144c62571344dc8) Signed-off-by: Xiangrui Meng --- .../spark/mllib/tree/DecisionTree.scala | 289 ++++++++---------- .../mllib/tree/configuration/Strategy.scala | 43 ++- .../spark/mllib/tree/impl/TimeTracker.scala | 73 +++++ .../spark/mllib/tree/impl/TreePoint.scala | 201 ++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 50 +-- 5 files changed, 449 insertions(+), 207 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index bb50f07be5d7b..2a3107a13e916 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,22 +17,24 @@ package org.apache.spark.mllib.tree -import org.apache.spark.api.java.JavaRDD - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom + /** * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. @@ -53,16 +55,27 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Cache input RDD for speedup during multiple passes. - val retaggedInput = input.retag(classOf[LabeledPoint]).cache() + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. + timer.start("findSplitsBins") val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length + timer.stop("findSplitsBins") logDebug("numBins = " + numBins) + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + .persist(StorageLevel.MEMORY_AND_DISK) + // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree @@ -76,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = retaggedInput.take(1)(0).features.size + val numFeatures = treeInput.take(1)(0).binnedFeatures.size // Calculate level for single group construction @@ -96,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) + timer.stop("init") + /* * The main idea here is to perform level-wise training of the decision tree nodes thus * reducing the passes over the data from l to log2(l) where l is the total number of nodes. @@ -113,15 +128,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup) + timer.start("findBestSplits") + val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, + strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) + timer.stop("findBestSplits") for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + timer.start("extractNodeInfo") // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) + timer.stop("extractNodeInfo") + timer.start("extractInfoForLowerLevels") // Extract info for nodes at the next lower level. extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) + timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } require(math.pow(2, level) == splitsStatsForLevel.length) @@ -144,6 +165,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + new DecisionTreeModel(topNode, strategy.algo) } @@ -406,7 +432,7 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -415,44 +441,45 @@ object DecisionTree extends Serializable with Logging { * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array of splits with best splits for all nodes at a given level. + * @return array (over nodes) of splits with best split for each node at a given level. */ protected[tree] def findBestSplits( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + maxLevelForSingleGroup: Int, + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, numGroups, groupIndex) + filters, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer) } } /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -465,13 +492,14 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], + timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -507,7 +535,7 @@ object DecisionTree extends Serializable with Logging { logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.size + val numFeatures = input.first().binnedFeatures.size logDebug("numFeatures = " + numFeatures) // numBins: Number of bins = 1 + number of possible splits @@ -542,33 +570,43 @@ object DecisionTree extends Serializable with Logging { * Find whether the sample is valid input for the current node, i.e., whether it passes through * all the filters for the current node. */ - def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { // leaf if ((level > 0) && (parentFilters.length == 0)) { return false } // Apply each filter and check sample validity. Return false when invalid condition found. - for (filter <- parentFilters) { - val features = labeledPoint.features + parentFilters.foreach { filter => val featureIndex = filter.split.feature - val threshold = filter.split.threshold val comparison = filter.comparison - val categories = filter.split.categories val isFeatureContinuous = filter.split.featureType == Continuous - val feature = features(featureIndex) if (isFeatureContinuous) { + val binId = treePoint.binnedFeatures(featureIndex) + val bin = bins(featureIndex)(binId) + val featureValue = bin.highSplit.threshold + val threshold = filter.split.threshold comparison match { - case -1 => if (feature > threshold) return false - case 1 => if (feature <= threshold) return false + case -1 => if (featureValue > threshold) return false + case 1 => if (featureValue <= threshold) return false } } else { - val containsFeature = categories.contains(feature) + val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + val featureValue = if (isUnorderedFeature) { + treePoint.binnedFeatures(featureIndex) + } else { + val binId = treePoint.binnedFeatures(featureIndex) + bins(featureIndex)(binId).category + } + val containsFeature = filter.split.categories.contains(featureValue) comparison match { case -1 => if (!containsFeature) return false case 1 => if (containsFeature) return false } - } } @@ -576,103 +614,6 @@ object DecisionTree extends Serializable with Logging { true } - /** - * Find bin for one (labeledPoint, feature). - */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, - isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } - else if (lowThreshold >= feature) { - right = mid - 1 - } - else { - left = mid + 1 - } - } - -1 - } - - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). - */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex - } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - -1 - } - - if (isFeatureContinuous) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex - } else { - // Perform sequential search to find bin for categorical features. - val binIndex = { - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - } - if (binIndex == -1) { - throw new UnknownError("no bin was found for categorical variable.") - } - binIndex - } - } - /** * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: @@ -689,17 +630,17 @@ object DecisionTree extends Serializable with Logging { * bin index for this labeledPoint * (or InvalidBinIndex if labeledPoint is not handled by this node) */ - def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + def findBinsForLevel(treePoint: TreePoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) // First element of the array is the label of the instance. - arr(0) = labeledPoint.label + arr(0) = treePoint.label // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, labeledPoint) + val sampleValid = isSampleValid(parentFilters, treePoint) val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. @@ -707,19 +648,7 @@ object DecisionTree extends Serializable with Logging { } else { var featureIndex = 0 while (featureIndex < numFeatures) { - val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) - } + arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex) featureIndex += 1 } } @@ -728,7 +657,8 @@ object DecisionTree extends Serializable with Logging { arr } - // Find feature bins for all nodes at a level. + // Find feature bins for all nodes at a level. + timer.start("aggregation") val binMappedRDD = input.map(x => findBinsForLevel(x)) /** @@ -830,6 +760,8 @@ object DecisionTree extends Serializable with Logging { } } + val rightChildShift = numClasses * numBins * numFeatures * numNodes + /** * Helper for binSeqOp. * @@ -853,7 +785,6 @@ object DecisionTree extends Serializable with Logging { val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes // actual class label val label = arr(0) // Iterate over all features. @@ -912,7 +843,7 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + agg(aggIndex + 2) = agg(aggIndex + 2) + label * label featureIndex += 1 } } @@ -977,6 +908,7 @@ object DecisionTree extends Serializable with Logging { val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } + timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) /** @@ -1031,10 +963,17 @@ object DecisionTree extends Serializable with Logging { def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => - if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex + 1) + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") } - if (result._1 < 0) 0 else result._1 + result._1 } val predict = indexOfLargestArrayElement(leftRightCounts) @@ -1057,6 +996,7 @@ object DecisionTree extends Serializable with Logging { val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) @@ -1280,15 +1220,41 @@ object DecisionTree extends Serializable with Logging { nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1) { + var featureIndex = 0 + while (featureIndex < numFeatures) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplitsForFeature) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) + splitIndex += 1 } + featureIndex += 1 } gains } + /** + * Get the number of splits for a feature. + */ + def getNumSplitsForFeature(featureIndex: Int): Int = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { + // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { + // Ordered features + featureCategories + } + } + } + /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. @@ -1307,7 +1273,7 @@ object DecisionTree extends Serializable with Logging { // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue @@ -1317,22 +1283,8 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex: Double = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - numBins - 1 - } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { // Binary classification - featureCategories - } - } - } - while (splitIndex < maxSplitIndex) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + while (splitIndex < numSplitsForFeature) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -1383,6 +1335,7 @@ object DecisionTree extends Serializable with Logging { } // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 @@ -1395,6 +1348,8 @@ object DecisionTree extends Serializable with Logging { bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } + timer.stop("chooseSplits") + bestSplits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index f31a503608b22..cfc8192a85abd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** * :: Experimental :: * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param impurity Criterion used for information gain calculation. + * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.Entropy]]. + * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * Default value is 2 (binary classification). + * @param maxBins Maximum number of bins used for discretizing continuous features and + * for choosing how to split on features at each node. + * More bins give higher granularity. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * */ @Experimental class Strategy ( @@ -64,20 +72,7 @@ class Strategy ( = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** - * Java-friendly constructor. - * - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, an entry - * (n -> k) implies the feature n is categorical with k categories - * 0, 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ def this( algo: Algo, @@ -90,6 +85,10 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Check validity of parameters. + * Throws exception if invalid. + */ private[tree] def assertValid(): Unit = { algo match { case Classification => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000000..d215d68c4279e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +import org.apache.spark.annotation.Experimental + +/** + * Time tracker implementation which holds labeled timers. + */ +@Experimental +private[tree] class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. + */ + def start(timerLabel: String): Unit = { + val currentTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = currentTime + } + + /** + * Stops a timer and returns the elapsed time in seconds. + */ + def stop(timerLabel: String): Double = { + val currentTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = currentTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed / 1e9 + } + + /** + * Print all timing results in seconds. + */ + override def toString: String = { + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala new file mode 100644 index 0000000000000..ccac1031fd9d9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Bin +import org.apache.spark.rdd.RDD + + +/** + * Internal representation of LabeledPoint for DecisionTree. + * This bins feature values based on a subsampled of data as follows: + * (a) Continuous features are binned into ranges. + * (b) Unordered categorical features are binned based on subsets of feature values. + * "Unordered categorical features" are categorical features with low arity used in + * multiclass classification. + * (c) Ordered categorical features are binned based on feature values. + * "Ordered categorical features" are categorical features with high arity, + * or any categorical feature used in regression or binary classification. + * + * @param label Label from LabeledPoint + * @param binnedFeatures Binned feature values. + * Same length as LabeledPoint.features, but values are bin indices. + */ +private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) + extends Serializable { +} + +private[tree] object TreePoint { + + /** + * Convert an input dataset into its TreePoint representation, + * binning feature values in preparation for DecisionTree training. + * @param input Input dataset. + * @param strategy DecisionTree training info, used for dataset metadata. + * @param bins Bins for features, of size (numFeatures, numBins). + * @return TreePoint dataset representation + */ + def convertToTreeRDD( + input: RDD[LabeledPoint], + strategy: Strategy, + bins: Array[Array[Bin]]): RDD[TreePoint] = { + input.map { x => + TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins, + strategy.categoricalFeaturesInfo) + } + } + + /** + * Convert one LabeledPoint into its TreePoint representation. + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + isMulticlassClassification: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): TreePoint = { + + val numFeatures = labeledPoint.features.size + val numBins = bins(0).size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 + while (featureIndex < numFeatures) { + val featureInfo = categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false, + bins, categoricalFeaturesInfo) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isUnorderedFeature, bins, categoricalFeaturesInfo) + } + featureIndex += 1 + } + + new TreePoint(labeledPoint.label, arr) + } + + /** + * Find bin for one (labeledPoint, feature). + * + * @param isUnorderedFeature (only applies if feature is categorical) + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isUnorderedFeature: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): Int = { + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) && (highThreshold >= feature)) { + return mid + } else if (lowThreshold >= feature) { + right = mid - 1 + } else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + + /** + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). + */ + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { + val featureCategories = categoricalFeaturesInfo(featureIndex) + val featureValue = labeledPoint.features(featureIndex) + var binIndex = 0 + while (binIndex < featureCategories) { + val bin = bins(featureIndex)(binIndex) + val categories = bin.highSplit.categories + if (categories.contains(featureValue)) { + return binIndex + } + binIndex += 1 + } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1) { + throw new RuntimeException("No bin was found for continuous feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = if (isUnorderedFeature) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeature() + } + if (binIndex == -1) { + throw new RuntimeException("No bin was found for categorical feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 70ca7c8a266f2..a5c49a38dc08f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -21,11 +21,12 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} +import org.apache.spark.mllib.tree.impl.TreePoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -41,7 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { prediction != expected.label } val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy) + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } def validateRegressor( @@ -54,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { err * err }.sum val mse = squaredError / input.length - assert(mse <= requiredMSE) + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } test("split and bin calculation") { @@ -427,7 +429,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -454,7 +457,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -499,7 +503,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -521,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -544,7 +550,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -567,7 +574,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -596,7 +604,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -604,7 +613,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) @@ -630,7 +639,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -689,7 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -714,7 +725,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -738,7 +750,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -757,7 +770,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) From 077213bae09737ccb904f07b2766d43bb0734477 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 15 Aug 2014 17:04:15 -0700 Subject: [PATCH 150/231] [SPARK-3046] use executor's class loader as the default serializer classloader The serializer is not always used in an executor thread (e.g. connection manager, broadcast), in which case the classloader might not have the user jar set, leading to corruption in deserialization. https://issues.apache.org/jira/browse/SPARK-3046 https://issues.apache.org/jira/browse/SPARK-2878 Author: Reynold Xin Closes #1972 from rxin/kryoBug and squashes the following commits: c1c7bf0 [Reynold Xin] Made change to JavaSerializer. 7204c33 [Reynold Xin] Added imports back. d879e67 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer class loader. (cherry picked from commit cc3648774e9a744850107bb187f2828d447e0a48) Signed-off-by: Reynold Xin --- .../org/apache/spark/executor/Executor.scala | 3 + .../spark/serializer/JavaSerializer.scala | 9 ++- .../spark/serializer/KryoSerializer.scala | 9 ++- .../apache/spark/serializer/Serializer.scala | 17 +++++ .../KryoSerializerDistributedSuite.scala | 71 +++++++++++++++++++ .../serializer/KryoSerializerSuite.scala | 23 +++++- 6 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index eac1f2326a29d..fb3f7bd54bbfa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -99,6 +99,9 @@ private[spark] class Executor( private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + // Set the classloader for serializer + env.serializer.setDefaultClassLoader(urlClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 34bc3124097bb..af33a2f2ca3e1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,7 +63,9 @@ extends DeserializationStream { def close() { objIn.close() } } -private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { +private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) + extends SerializerInstance { + def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) @@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) - def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) + override def newInstance(): SerializerInstance = { + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) + new JavaSerializerInstance(counterReset, classLoader) + } override def writeExternal(out: ObjectOutput) { out.writeInt(counterReset) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 85944eabcfefc..99682220b4ab5 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf) val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() kryo.setRegistrationRequired(registrationRequired) - val classLoader = Thread.currentThread.getContextClassLoader + + val oldClassLoader = Thread.currentThread.getContextClassLoader + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. @@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf) try { val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] + + // Use the default classloader when calling the user registrator. + Thread.currentThread.setContextClassLoader(classLoader) reg.registerClasses(kryo) } catch { case e: Exception => throw new SparkException(s"Failed to invoke $regCls", e) + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f2f5cea469c61..e674438c8176c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator} */ @DeveloperApi trait Serializer { + + /** + * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should + * make sure it is using this when set. + */ + @volatile protected var defaultClassLoader: Option[ClassLoader] = None + + /** + * Sets a class loader for the serializer to use in deserialization. + * + * @return this Serializer object + */ + def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + defaultClassLoader = Some(classLoader) + this + } + def newInstance(): SerializerInstance } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala new file mode 100644 index 0000000000000..11e8c9c4cb37f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.apache.spark.util.Utils + +import com.esotericsoftware.kryo.Kryo +import org.scalatest.FunSuite + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark.SparkContext._ +import org.apache.spark.serializer.KryoDistributedTest._ + +class KryoSerializerDistributedSuite extends FunSuite { + + test("kryo objects are serialised consistently in different processes") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + conf.set("spark.task.maxFailures", "1") + + val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) + conf.setJars(List(jar.getPath)) + + val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val original = Thread.currentThread.getContextClassLoader + val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) + SparkEnv.get.serializer.setDefaultClassLoader(loader) + + val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache() + + // Randomly mix the keys so that the join below will require a shuffle with each partition + // sending data to multiple other partitions. + val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)} + + // Join the two RDDs, and force evaluation + assert(shuffledRDD.join(cachedRDD).collect().size == 1) + + LocalSparkContext.stop(sc) + } +} + +object KryoDistributedTest { + class MyCustomClass + + class AppJarRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + val classLoader = Thread.currentThread.getContextClassLoader + k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + } + } + + object AppJarRegistrator { + val customClassName = "KryoSerializerDistributedSuiteCustomClass" + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 3bf9efebb39d2..a579fd50bd9e4 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ class KryoSerializerSuite extends FunSuite with SharedSparkContext { @@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) } + + test("default class loader can be set by a different thread") { + val ser = new KryoSerializer(new SparkConf) + + // First serialize the object + val serInstance = ser.newInstance() + val bytes = serInstance.serialize(new ClassLoaderTestingObject) + + // Deserialize the object to make sure normal deserialization works + serInstance.deserialize[ClassLoaderTestingObject](bytes) + + // Set a special, broken ClassLoader and make sure we get an exception on deserialization + ser.setDefaultClassLoader(new ClassLoader() { + override def loadClass(name: String) = throw new UnsupportedOperationException + }) + intercept[UnsupportedOperationException] { + ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) + } + } } +class ClassLoaderTestingObject + class KryoSerializerResizableOutputSuite extends FunSuite { import org.apache.spark.SparkConf import org.apache.spark.SparkContext From c085011cac4df1bf4cbaef00a8b921ace6e3123b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 15 Aug 2014 21:04:29 -0700 Subject: [PATCH 151/231] [SPARK-3078][MLLIB] Make LRWithLBFGS API consistent with others Should ask users to set parameters through the optimizer. dbtsai Author: Xiangrui Meng Closes #1973 from mengxr/lr-lbfgs and squashes the following commits: e3efbb1 [Xiangrui Meng] fix tests 21b3579 [Xiangrui Meng] fix method name 641eea4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into lr-lbfgs 456ab7c [Xiangrui Meng] update LRWithLBFGS (cherry picked from commit 5d25c0b74f6397d78164b96afb8b8cbb1b15cfbd) Signed-off-by: Xiangrui Meng --- .../examples/mllib/BinaryClassification.scala | 8 ++-- .../classification/LogisticRegression.scala | 40 +++---------------- .../spark/mllib/optimization/LBFGS.scala | 9 +++++ .../LogisticRegressionSuite.scala | 5 ++- .../spark/mllib/optimization/LBFGSSuite.scala | 24 +++++------ 5 files changed, 33 insertions(+), 53 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 56b02b65d8724..a6f78d2441db1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -21,7 +21,7 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD} +import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater} @@ -66,7 +66,8 @@ object BinaryClassification { .text("number of iterations") .action((x, c) => c.copy(numIterations = x)) opt[Double]("stepSize") - .text(s"initial step size, default: ${defaultParams.stepSize}") + .text("initial step size (ignored by logistic regression), " + + s"default: ${defaultParams.stepSize}") .action((x, c) => c.copy(stepSize = x)) opt[String]("algorithm") .text(s"algorithm (${Algorithm.values.mkString(",")}), " + @@ -125,10 +126,9 @@ object BinaryClassification { val model = params.algorithm match { case LR => - val algorithm = new LogisticRegressionWithSGD() + val algorithm = new LogisticRegressionWithLBFGS() algorithm.optimizer .setNumIterations(params.numIterations) - .setStepSize(params.stepSize) .setUpdater(updater) .setRegParam(params.regParam) algorithm.run(training).clearThreshold() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 6790c86f651b4..486bdbfa9cb47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -73,6 +73,8 @@ class LogisticRegressionModel ( /** * Train a classification model for Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ class LogisticRegressionWithSGD private ( private var stepSize: Double, @@ -191,51 +193,19 @@ object LogisticRegressionWithSGD { /** * Train a classification model for Logistic Regression using Limited-memory BFGS. + * Standard feature scaling and L2 regularization are used by default. * NOTE: Labels used in Logistic Regression should be {0, 1} */ -class LogisticRegressionWithLBFGS private ( - private var convergenceTol: Double, - private var maxNumIterations: Int, - private var regParam: Double) +class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { - /** - * Construct a LogisticRegression object with default parameters - */ - def this() = this(1E-4, 100, 0.0) - this.setFeatureScaling(true) - private val gradient = new LogisticGradient() - private val updater = new SimpleUpdater() - // Have to return new LBFGS object every time since users can reset the parameters anytime. - override def optimizer = new LBFGS(gradient, updater) - .setNumCorrections(10) - .setConvergenceTol(convergenceTol) - .setMaxNumIterations(maxNumIterations) - .setRegParam(regParam) + override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) override protected val validators = List(DataValidators.binaryLabelValidator) - /** - * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. - * Smaller value will lead to higher accuracy with the cost of more iterations. - */ - def setConvergenceTol(convergenceTol: Double): this.type = { - this.convergenceTol = convergenceTol - this - } - - /** - * Set the maximal number of iterations for L-BFGS. Default 100. - */ - def setNumIterations(numIterations: Int): this.type = { - this.maxNumIterations = numIterations - this - } - override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 033fe44f34f3c..d16d0daf08565 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -69,8 +69,17 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the maximal number of iterations for L-BFGS. Default 100. + * @deprecated use [[LBFGS#setNumIterations]] instead */ + @deprecated("use setNumIterations instead", "1.1.0") def setMaxNumIterations(iters: Int): this.type = { + this.setNumIterations(iters) + } + + /** + * Set the maximal number of iterations for L-BFGS. Default 100. + */ + def setNumIterations(iters: Int): this.type = { this.maxNumIterations = iters this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index bc05b2046878f..862178694a50e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -272,8 +272,9 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. - val model = - (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points) + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + lr.optimizer.setNumIterations(2) + val model = lr.run(points) val predictions = model.predict(points.map(_.features)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 5f4c24115ac80..ccba004baa007 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -55,7 +55,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (_, loss) = LBFGS.runLBFGS( dataRDD, @@ -63,7 +63,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { simpleUpdater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -99,7 +99,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS( dataRDD, @@ -107,7 +107,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -140,10 +140,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { /** * For the first run, we set the convergenceTol to 0.0, so that the algorithm will - * run up to the maxNumIterations which is 8 here. + * run up to the numIterations which is 8 here. */ val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0) - val maxNumIterations = 8 + val numIterations = 8 var convergenceTol = 0.0 val (_, lossLBFGS1) = LBFGS.runLBFGS( @@ -152,7 +152,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -167,7 +167,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -182,7 +182,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -200,12 +200,12 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) .setNumCorrections(numCorrections) .setConvergenceTol(convergenceTol) - .setMaxNumIterations(maxNumIterations) + .setNumIterations(numIterations) .setRegParam(regParam) val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) @@ -241,7 +241,7 @@ class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater) .setNumCorrections(1) .setConvergenceTol(1e-12) - .setMaxNumIterations(1) + .setNumIterations(1) .setRegParam(1.0) val random = new Random(0) // If we serialize data directly in the task closure, the size of the serialized task would be From ce06d7f45bc551f6121c382b0833e01b8a83f636 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 15 Aug 2014 21:07:55 -0700 Subject: [PATCH 152/231] [SPARK-3001][MLLIB] Improve Spearman's correlation The current implementation requires sorting individual columns, which could be done with a global sort. result on a 32-node cluster: m | n | prev | this ---|---|-------|----- 1000000 | 50 | 55s | 9s 10000000 | 50 | 97s | 76s 1000000 | 100 | 119s | 15s Author: Xiangrui Meng Closes #1917 from mengxr/spearman and squashes the following commits: 4d5d262 [Xiangrui Meng] remove unused import 85c48de [Xiangrui Meng] minor updates a048d0c [Xiangrui Meng] remove cache and set a limit to cachedIds b98bb18 [Xiangrui Meng] add comments 0846e07 [Xiangrui Meng] first version (cherry picked from commit 2e069ca6560bf7ab07bd019f9530b42f4fe45014) Signed-off-by: Xiangrui Meng --- .../correlation/SpearmanCorrelation.scala | 120 ++++++------------ 1 file changed, 42 insertions(+), 78 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 9bd0c2cd05de4..4a6c677f06d28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.stat.correlation import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, HashPartitioner} +import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} -import org.apache.spark.rdd.{CoGroupedRDD, RDD} +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors} +import org.apache.spark.rdd.RDD /** * Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix @@ -43,87 +43,51 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { /** * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the * correlation between column i and j. - * - * Input RDD[Vector] should be cached or checkpointed if possible since it would be split into - * numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector]. */ override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { - val indexed = X.zipWithUniqueId() - - val numCols = X.first.size - if (numCols > 50) { - logWarning("Computing the Spearman correlation matrix can be slow for large RDDs with more" - + " than 50 columns.") - } - val ranks = new Array[RDD[(Long, Double)]](numCols) - - // Note: we use a for loop here instead of a while loop with a single index variable - // to avoid race condition caused by closure serialization - for (k <- 0 until numCols) { - val column = indexed.map { case (vector, index) => (vector(k), index) } - ranks(k) = getRanks(column) + // ((columnIndex, value), rowUid) + val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) => + vec.toArray.view.zipWithIndex.map { case (v, j) => + ((j, v), uid) + } } - - val ranksMat: RDD[Vector] = makeRankMatrix(ranks, X) - PearsonCorrelation.computeCorrelationMatrix(ranksMat) - } - - /** - * Compute the ranks for elements in the input RDD, using the average method for ties. - * - * With the average method, elements with the same value receive the same rank that's computed - * by taking the average of their positions in the sorted list. - * e.g. ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5] - * Note that positions here are 0-indexed, instead of the 1-indexed as in the definition for - * ranks in the standard definition for Spearman's correlation. This does not affect the final - * results and is slightly more performant. - * - * @param indexed RDD[(Double, Long)] containing pairs of the format (originalValue, uniqueId) - * @return RDD[(Long, Double)] containing pairs of the format (uniqueId, rank), where uniqueId is - * copied from the input RDD. - */ - private def getRanks(indexed: RDD[(Double, Long)]): RDD[(Long, Double)] = { - // Get elements' positions in the sorted list for computing average rank for duplicate values - val sorted = indexed.sortByKey().zipWithIndex() - - val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => - // add an extra element to signify the end of the list so that flatMap can flush the last - // batch of duplicates - val end = -1L - val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) - val firstEntry = padded.next() - var lastVal = firstEntry._1._1 - var firstRank = firstEntry._2.toDouble - val idBuffer = ArrayBuffer(firstEntry._1._2) - padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != end) { - idBuffer += id - Iterator.empty - } else { - val entries = if (idBuffer.size == 1) { - Iterator((idBuffer(0), firstRank)) - } else { - val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 - idBuffer.map(id => (id, averageRank)) - } - lastVal = v - firstRank = rank - idBuffer.clear() - idBuffer += id - entries + // global sort by (columnIndex, value) + val sorted = colBased.sortByKey() + // assign global ranks (using average ranks for tied values) + val globalRanks = sorted.zipWithIndex().mapPartitions { iter => + var preCol = -1 + var preVal = Double.NaN + var startRank = -1.0 + var cachedUids = ArrayBuffer.empty[Long] + val flush: () => Iterable[(Long, (Int, Double))] = () => { + val averageRank = startRank + (cachedUids.size - 1) / 2.0 + val output = cachedUids.map { uid => + (uid, (preCol, averageRank)) } + cachedUids.clear() + output } + iter.flatMap { case (((j, v), uid), rank) => + // If we see a new value or cachedUids is too big, we flush ids with their average rank. + if (j != preCol || v != preVal || cachedUids.size >= 10000000) { + val output = flush() + preCol = j + preVal = v + startRank = rank + cachedUids += uid + output + } else { + cachedUids += uid + Iterator.empty + } + } ++ flush() } - ranks - } - - private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = { - val partitioner = new HashPartitioner(input.partitions.size) - val cogrouped = new CoGroupedRDD[Long](ranks, partitioner) - cogrouped.map { - case (_, values: Array[Iterable[_]]) => - val doubles = values.asInstanceOf[Array[Iterable[Double]]] - new DenseVector(doubles.flatten.toArray) + // Replace values in the input matrix by their ranks compared with values in the same column. + // Note that shifting all ranks in a column by a constant value doesn't affect result. + val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) => + // sort by column index and then convert values to a vector + Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray) } + PearsonCorrelation.computeCorrelationMatrix(groupedRanks) } } From 2541537217fd3f73e494c98d4c5e379723fe0199 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 15 Aug 2014 22:55:32 -0700 Subject: [PATCH 153/231] [SPARK-3015] Block on cleaning tasks to prevent Akka timeouts More detail on the issue is described in [SPARK-3015](https://issues.apache.org/jira/browse/SPARK-3015), but the TLDR is if we send too many blocking Akka messages that are dependent on each other in quick successions, then we end up causing a few of these messages to time out and ultimately kill the executors. As of #1498, we broadcast each RDD whether or not it is persisted. This means if we create many RDDs (each of which becomes a broadcast) and the driver performs a GC that cleans up all of these broadcast blocks, then we end up sending many `RemoveBroadcast` messages in parallel and trigger the chain of blocking messages at high frequencies. We do not know of the Akka-level root cause yet, so this is intended to be a temporary solution until we identify the real issue. I have done some preliminary testing of enabling blocking and observed that the queue length remains quite low (< 1000) even under very intensive workloads. In the long run, we should do something more sophisticated to allow a limited degree of parallelism through batching clean up tasks or processing them in a sliding window. In the longer run, we should clean up the whole `BlockManager*` message passing interface to avoid unnecessarily awaiting on futures created from Akka asks. tdas pwendell mengxr Author: Andrew Or Closes #1931 from andrewor14/reference-blocking and squashes the following commits: d0f7195 [Andrew Or] Merge branch 'master' of github.com:apache/spark into reference-blocking ce9daf5 [Andrew Or] Remove logic for logging queue length 111192a [Andrew Or] Add missing space in log message (minor) a183b83 [Andrew Or] Switch order of code blocks (minor) 9fd1fe6 [Andrew Or] Remove outdated log 104b366 [Andrew Or] Use the actual reference queue length 0b7e768 [Andrew Or] Block on cleaning tasks by default + log error on queue full (cherry picked from commit c9da466edb83e45a159ccc17c68856a511b9e8b7) Signed-off-by: Patrick Wendell --- .../main/scala/org/apache/spark/ContextCleaner.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bf3c3a6ceb5ef..3848734d6f639 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -66,10 +66,15 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * Whether the cleaning thread will block on cleanup tasks. - * This is set to true only for tests. + * + * Due to SPARK-3015, this is set to true by default. This is intended to be only a temporary + * workaround for the issue, which is ultimately caused by the way the BlockManager actors + * issue inter-dependent blocking Akka messages to each other at high frequencies. This happens, + * for instance, when the driver performs a GC and cleans up all broadcast blocks that are no + * longer in scope. */ private val blockOnCleanupTasks = sc.conf.getBoolean( - "spark.cleaner.referenceTracking.blocking", false) + "spark.cleaner.referenceTracking.blocking", true) @volatile private var stopped = false @@ -174,9 +179,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - - // Used for testing. These methods explicitly blocks until cleanup is completed - // to ensure that more reliable testing. } private object ContextCleaner { From fcf30cdc558aff4c615e4d8f0bbe30e39a0448e4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 15 Aug 2014 23:12:34 -0700 Subject: [PATCH 154/231] [SPARK-3045] Make Serializer interface Java friendly Author: Reynold Xin Closes #1948 from rxin/kryo and squashes the following commits: a3a80d8 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer classloader 3d13277 [Reynold Xin] Reverted that in TestJavaSerializerImpl too. 196f3dc [Reynold Xin] Ok one more commit to revert the classloader change. c49b50c [Reynold Xin] Removed JavaSerializer change. afbf37d [Reynold Xin] Moved the test case also. a2e693e [Reynold Xin] Removed the Kryo bug fix from this pull request. c81bd6c [Reynold Xin] Use defaultClassLoader when executing user specified custom registrator. 68f261e [Reynold Xin] Added license check excludes. 0c28179 [Reynold Xin] [SPARK-3045] Make Serializer interface Java friendly [SPARK-3046] Set executor's class loader as the default serializer class loader (cherry picked from commit a83c7723bf7a90dc6cd5dde98a179303b7542020) Signed-off-by: Reynold Xin --- .../spark/serializer/JavaSerializer.scala | 15 +-- .../spark/serializer/KryoSerializer.scala | 32 +++---- .../apache/spark/serializer/Serializer.scala | 25 ++--- .../apache/spark/serializer/package-info.java | 2 +- .../serializer/TestJavaSerializerImpl.java | 95 +++++++++++++++++++ .../KryoSerializerResizableOutputSuite.scala | 52 ++++++++++ .../serializer/KryoSerializerSuite.scala | 34 +------ project/MimaExcludes.scala | 11 +++ 8 files changed, 193 insertions(+), 73 deletions(-) create mode 100644 core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java create mode 100644 core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index af33a2f2ca3e1..554a33ce7f1a6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,10 +63,11 @@ extends DeserializationStream { def close() { objIn.close() } } + private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { + override def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) @@ -74,23 +75,23 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade ByteBuffer.wrap(bos.toByteArray) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis) - in.readObject().asInstanceOf[T] + in.readObject() } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] + in.readObject() } - def serializeStream(s: OutputStream): SerializationStream = { + override def serializeStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s, counterReset) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 99682220b4ab5..87ef9bb0b43c6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -91,7 +91,7 @@ class KryoSerializer(conf: SparkConf) Thread.currentThread.setContextClassLoader(classLoader) reg.registerClasses(kryo) } catch { - case e: Exception => + case e: Exception => throw new SparkException(s"Failed to invoke $regCls", e) } finally { Thread.currentThread.setContextClassLoader(oldClassLoader) @@ -106,7 +106,7 @@ class KryoSerializer(conf: SparkConf) kryo } - def newInstance(): SerializerInstance = { + override def newInstance(): SerializerInstance = { new KryoSerializerInstance(this) } } @@ -115,20 +115,20 @@ private[spark] class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { val output = new KryoOutput(outStream) - def writeObject[T: ClassTag](t: T): SerializationStream = { + override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - def flush() { output.flush() } - def close() { output.close() } + override def flush() { output.flush() } + override def close() { output.close() } } private[spark] class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - val input = new KryoInput(inStream) + private val input = new KryoInput(inStream) - def readObject[T: ClassTag](): T = { + override def readObject[T: ClassTag](): T = { try { kryo.readClassAndObject(input).asInstanceOf[T] } catch { @@ -138,31 +138,31 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } } - def close() { + override def close() { // Kryo's Input automatically closes the input stream it is using. input.close() } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.newKryo() + private val kryo = ks.newKryo() // Make these lazy vals to avoid creating a buffer unless we use them - lazy val output = ks.newKryoOutput() - lazy val input = new KryoInput() + private lazy val output = ks.newKryoOutput() + private lazy val input = new KryoInput() - def serialize[T: ClassTag](t: T): ByteBuffer = { + override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() kryo.writeClassAndObject(output, t) ByteBuffer.wrap(output.toBytes) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { input.setBuffer(bytes.array) kryo.readClassAndObject(input).asInstanceOf[T] } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) input.setBuffer(bytes.array) @@ -171,11 +171,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ obj } - def serializeStream(s: OutputStream): SerializationStream = { + override def serializeStream(s: OutputStream): SerializationStream = { new KryoSerializationStream(kryo, s) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new KryoDeserializationStream(kryo, s) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index e674438c8176c..a9144cdd97b8c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator} * They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi -trait Serializer { +abstract class Serializer { /** * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should @@ -61,10 +61,12 @@ trait Serializer { this } + /** Creates a new [[SerializerInstance]]. */ def newInstance(): SerializerInstance } +@DeveloperApi object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer @@ -81,7 +83,7 @@ object Serializer { * An instance of a serializer, for use by one thread at a time. */ @DeveloperApi -trait SerializerInstance { +abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer def deserialize[T: ClassTag](bytes: ByteBuffer): T @@ -91,21 +93,6 @@ trait SerializerInstance { def serializeStream(s: OutputStream): SerializationStream def deserializeStream(s: InputStream): DeserializationStream - - def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = { - // Default implementation uses serializeStream - val stream = new ByteArrayOutputStream() - serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.wrap(stream.toByteArray) - buffer.flip() - buffer - } - - def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - // Default implementation uses deserializeStream - buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).asIterator - } } /** @@ -113,7 +100,7 @@ trait SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -trait SerializationStream { +abstract class SerializationStream { def writeObject[T: ClassTag](t: T): SerializationStream def flush(): Unit def close(): Unit @@ -132,7 +119,7 @@ trait SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -trait DeserializationStream { +abstract class DeserializationStream { def readObject[T: ClassTag](): T def close(): Unit diff --git a/core/src/main/scala/org/apache/spark/serializer/package-info.java b/core/src/main/scala/org/apache/spark/serializer/package-info.java index 4c0b73ab36a00..207c6e02e4293 100644 --- a/core/src/main/scala/org/apache/spark/serializer/package-info.java +++ b/core/src/main/scala/org/apache/spark/serializer/package-info.java @@ -18,4 +18,4 @@ /** * Pluggable serializers for RDD and shuffle data. */ -package org.apache.spark.serializer; \ No newline at end of file +package org.apache.spark.serializer; diff --git a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java new file mode 100644 index 0000000000000..3d50ab4fabe42 --- /dev/null +++ b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.Option; +import scala.reflect.ClassTag; + + +/** + * A simple Serializer implementation to make sure the API is Java-friendly. + */ +class TestJavaSerializerImpl extends Serializer { + + @Override + public SerializerInstance newInstance() { + return null; + } + + static class SerializerInstanceImpl extends SerializerInstance { + @Override + public ByteBuffer serialize(T t, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag evidence$1) { + return null; + } + + @Override + public SerializationStream serializeStream(OutputStream s) { + return null; + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + return null; + } + } + + static class SerializationStreamImpl extends SerializationStream { + + @Override + public SerializationStream writeObject(T t, ClassTag evidence$1) { + return null; + } + + @Override + public void flush() { + + } + + @Override + public void close() { + + } + } + + static class DeserializationStreamImpl extends DeserializationStream { + + @Override + public T readObject(ClassTag evidence$1) { + return null; + } + + @Override + public void close() { + + } + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala new file mode 100644 index 0000000000000..967c9e9899c9d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.LocalSparkContext +import org.apache.spark.SparkException + + +class KryoSerializerResizableOutputSuite extends FunSuite { + + // trial and error showed this will not serialize with 1mb buffer + val x = (1 to 400000).toArray + + test("kryo without resizable output buffer should fail on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "1") + val sc = new SparkContext("local", "test", conf) + intercept[SparkException](sc.parallelize(x).collect()) + LocalSparkContext.stop(sc) + } + + test("kryo with resizable output buffer should succeed on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "2") + val sc = new SparkContext("local", "test", conf) + assert(sc.parallelize(x).collect() === x) + LocalSparkContext.stop(sc) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a579fd50bd9e4..e1e35b688d581 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ + class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) @@ -207,7 +208,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } - + test("kryo with nonexistent custom registrator should fail") { import org.apache.spark.{SparkConf, SparkException} @@ -238,39 +239,12 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } -class ClassLoaderTestingObject - -class KryoSerializerResizableOutputSuite extends FunSuite { - import org.apache.spark.SparkConf - import org.apache.spark.SparkContext - import org.apache.spark.LocalSparkContext - import org.apache.spark.SparkException - - // trial and error showed this will not serialize with 1mb buffer - val x = (1 to 400000).toArray - test("kryo without resizable output buffer should fail on large array") { - val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") - val sc = new SparkContext("local", "test", conf) - intercept[SparkException](sc.parallelize(x).collect) - LocalSparkContext.stop(sc) - } +class ClassLoaderTestingObject - test("kryo with resizable output buffer should succeed on large array") { - val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") - val sc = new SparkContext("local", "test", conf) - assert(sc.parallelize(x).collect === x) - LocalSparkContext.stop(sc) - } -} object KryoTest { + case class CaseClass(i: Int, s: String) {} class ClassWithNoArgConstructor { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1e3c760b845de..bbe68b29d2d8e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -61,6 +61,17 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.storage.MemoryStore.Entry") ) ++ + Seq( + // Serializer interface change. See SPARK-3045. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.DeserializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.Serializer"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializerInstance") + )++ Seq( // Renamed putValues -> putArray + putIterator ProblemFilters.exclude[MissingMethodProblem]( From 0e0ec2eeb1eab1cb6dabbaa60d30242d0d7e292f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 16 Aug 2014 00:04:55 -0700 Subject: [PATCH 155/231] [SPARK-2977] Ensure ShuffleManager is created before ShuffleBlockManager This is intended to fix SPARK-2977. Before, there was an implicit ordering dependency where we needed to know the ShuffleManager implementation before creating the ShuffleBlockManager. This patch makes that dependency explicit by adding ShuffleManager to a bunch of constructors. I think it's a little odd for BlockManager to take a ShuffleManager only to pass it to ShuffleBlockManager without using it itself; there's an opportunity to clean this up later if we sever the circular dependencies between BlockManager and other components and pass those components to BlockManager's constructor. Author: Josh Rosen Closes #1976 from JoshRosen/SPARK-2977 and squashes the following commits: a9cd1e1 [Josh Rosen] [SPARK-2977] Ensure ShuffleManager is created before ShuffleBlockManager. (cherry picked from commit 20fcf3d0b72f3707dc1ed95d453f570fabdefd16) Signed-off-by: Josh Rosen --- .../scala/org/apache/spark/SparkEnv.scala | 22 +++++++++---------- .../apache/spark/storage/BlockManager.scala | 11 ++++++---- .../spark/storage/ShuffleBlockManager.scala | 7 +++--- .../apache/spark/storage/ThreadingTest.scala | 3 ++- .../spark/storage/BlockManagerSuite.scala | 12 +++++----- .../spark/storage/DiskBlockManagerSuite.scala | 8 +++++-- 6 files changed, 37 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 22d8d1cb1ddcf..fc36e37c53f5e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -210,12 +210,22 @@ object SparkEnv extends Logging { "MapOutputTracker", new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + // Let the user specify short names for shuffle managers + val shortShuffleMgrNames = Map( + "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) + + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker) + serializer, conf, securityManager, mapOutputTracker, shuffleManager) val connectionManager = blockManager.connectionManager @@ -250,16 +260,6 @@ object SparkEnv extends Logging { "." } - // Let the user specify short names for shuffle managers - val shortShuffleMgrNames = Map( - "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") - val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) - val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - - val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e8bbd298c631a..e4c3d58905e7f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -33,6 +33,7 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -57,11 +58,12 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) extends Logging { private val port = conf.getInt("spark.blockManager.port", 0) - val shuffleBlockManager = new ShuffleBlockManager(this) + val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) val connectionManager = @@ -142,9 +144,10 @@ private[spark] class BlockManager( serializer: Serializer, conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) = { + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker) + conf, securityManager, mapOutputTracker, shuffleManager) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 3565719b54545..b8f5d3a5b02aa 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.Logging import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} @@ -62,7 +63,8 @@ private[spark] trait ShuffleWriterGroup { */ // TODO: Factor this into a separate class for each ShuffleManager implementation private[spark] -class ShuffleBlockManager(blockManager: BlockManager) extends Logging { +class ShuffleBlockManager(blockManager: BlockManager, + shuffleManager: ShuffleManager) extends Logging { def conf = blockManager.conf // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. @@ -71,8 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { conf.getBoolean("spark.shuffle.consolidateFiles", false) // Are we using sort-based shuffle? - val sortBasedShuffle = - conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName + val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager] private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 75c2e09a6bbb8..aa83ea90ee9ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.util.concurrent.ArrayBlockingQueue import akka.actor._ +import org.apache.spark.shuffle.hash.HashShuffleManager import util.Random import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} @@ -101,7 +102,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf)) + new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 94bb2c445d2e9..20bac66105a69 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit import akka.actor._ import akka.pattern.ask import akka.util.Timeout +import org.apache.spark.shuffle.hash.HashShuffleManager import org.mockito.invocation.InvocationOnMock import org.mockito.Matchers.any @@ -61,6 +62,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -71,8 +73,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { - new BlockManager( - name, actorSystem, master, serializer, maxMem, conf, securityMgr, mapOutputTracker) + new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr, + mapOutputTracker, shuffleManager) } before { @@ -791,7 +793,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) // The put should fail since a1 is not serializable. class UnserializableClass @@ -1007,7 +1009,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) val worker = spy(new BlockManagerWorker(store)) val connManagerId = mock(classOf[ConnectionManagerId]) @@ -1054,7 +1056,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) val worker = spy(new BlockManagerWorker(store)) val connManagerId = mock(classOf[ConnectionManagerId]) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index b8299e2ea187f..777579bc570db 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import org.apache.spark.shuffle.hash.HashShuffleManager + import scala.collection.mutable import scala.language.reflectiveCalls @@ -42,7 +44,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // so we coerce consolidation if not already enabled. testConf.set("spark.shuffle.consolidateFiles", "true") - val shuffleBlockManager = new ShuffleBlockManager(null) { + private val shuffleManager = new HashShuffleManager(testConf.clone) + + val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) { override def conf = testConf.clone var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) @@ -148,7 +152,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), confCopy) val store = new BlockManager("", actorSystem, master , serializer, confCopy, - securityManager, null) + securityManager, null, shuffleManager) try { From 8c79574462eed113fc59d4323eedfc55c6e95c06 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 Aug 2014 11:26:51 -0700 Subject: [PATCH 156/231] [SQL] Using safe floating-point numbers in doctest Test code in `sql.py` tries to compare two floating-point numbers directly, and cased [build failure(s)](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/18365/consoleFull). [Doctest documentation](https://docs.python.org/3/library/doctest.html#warnings) recommends using numbers in the form of `I/2**J` to avoid the precision issue. Author: Cheng Lian Closes #1925 from liancheng/fix-pysql-fp-test and squashes the following commits: 0fbf584 [Cheng Lian] Removed unnecessary `...' from inferSchema doctest e8059d4 [Cheng Lian] Using safe floating-point numbers in doctest (cherry picked from commit b4a05928e95c0f6973fd21e60ff9c108f226e38c) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 95086a2258222..d4ca0cc8f336e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1093,8 +1093,8 @@ def applySchema(self, rdd, schema): >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.1 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1...)] + ... "float + 1.5 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), From bd3ce2ffb8964abb4d59918ebb2c230fe4614aa2 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 16 Aug 2014 14:15:58 -0700 Subject: [PATCH 157/231] [SPARK-2677] BasicBlockFetchIterator#next can wait forever Author: Kousuke Saruta Closes #1632 from sarutak/SPARK-2677 and squashes the following commits: cddbc7b [Kousuke Saruta] Removed Exception throwing when ConnectionManager#handleMessage receives ack for non-referenced message d3bd2a8 [Kousuke Saruta] Modified configuration.md for spark.core.connection.ack.timeout e85f88b [Kousuke Saruta] Removed useless synchronized blocks 7ed48be [Kousuke Saruta] Modified ConnectionManager to use ackTimeoutMonitor ConnectionManager-wide 9b620a6 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 0dd9ad3 [Kousuke Saruta] Modified typo in ConnectionManagerSuite.scala 7cbb8ca [Kousuke Saruta] Modified to match with scalastyle 8a73974 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 ade279a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 0174d6a [Kousuke Saruta] Modified ConnectionManager.scala to handle the case remote Executor cannot ack a454239 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 9b7b7c1 [Kousuke Saruta] (WIP) Modifying ConnectionManager.scala (cherry picked from commit 76fa0eaf515fd6771cdd69422b1259485debcae5) Signed-off-by: Josh Rosen --- .../spark/network/ConnectionManager.scala | 45 ++++++++++++++----- .../network/ConnectionManagerSuite.scala | 44 +++++++++++++++++- docs/configuration.md | 9 ++++ 3 files changed, 87 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 95f96b8463a01..37d69a9ec4ce4 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -22,6 +22,7 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ +import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} @@ -61,17 +62,17 @@ private[spark] class ConnectionManager( var ackMessage: Option[Message] = None def markDone(ackMessage: Option[Message]) { - this.synchronized { - this.ackMessage = ackMessage - completionHandler(this) - } + this.ackMessage = ackMessage + completionHandler(this) } } private val selector = SelectorProvider.provider.openSelector() + private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) // default to 30 second timeout waiting for authentication private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -652,19 +653,27 @@ private[spark] class ConnectionManager( } } if (bufferMessage.hasAckId()) { - val sentMessageStatus = messageStatuses.synchronized { + messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status + status.markDone(Some(message)) } case None => { - throw new Exception("Could not find reference for received ack message " + - message.id) + /** + * We can fall down on this code because of following 2 cases + * + * (1) Invalid ack sent due to buggy code. + * + * (2) Late-arriving ack for a SendMessageStatus + * To avoid unwilling late-arriving ack + * caused by long pause like GC, you can set + * larger value than default to spark.core.connection.ack.wait.timeout + */ + logWarning(s"Could not find reference for received ack Message ${message.id}") } } } - sentMessageStatus.markDone(Some(message)) } else { var ackMessage : Option[Message] = None try { @@ -836,9 +845,23 @@ private[spark] class ConnectionManager( def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) : Future[Message] = { val promise = Promise[Message]() + + val timeoutTask = new TimerTask { + override def run(): Unit = { + messageStatuses.synchronized { + messageStatuses.remove(message.id).foreach ( s => { + promise.failure( + new IOException(s"sendMessageReliably failed because ack " + + "was not received within ${ackTimeout} sec")) + }) + } + } + } + val status = new MessageStatus(message, connectionManagerId, s => { + timeoutTask.cancel() s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd + case None => // Indicates a failure where we either never sent or never got ACK'd promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) case Some(ackMessage) => if (ackMessage.hasError) { @@ -852,6 +875,8 @@ private[spark] class ConnectionManager( messageStatuses.synchronized { messageStatuses += ((message.id, status)) } + + ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 846537df003df..e2f4d4c57cdb5 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -19,14 +19,19 @@ package org.apache.spark.network import java.io.IOException import java.nio._ +import java.util.concurrent.TimeoutException import org.apache.spark.{SecurityManager, SparkConf} import org.scalatest.FunSuite +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +import scala.concurrent.TimeoutException import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps -import scala.util.Try +import scala.util.{Failure, Success, Try} /** * Test the ConnectionManager with various security settings. @@ -255,5 +260,42 @@ class ConnectionManagerSuite extends FunSuite { } + test("sendMessageReliably timeout") { + val clientConf = new SparkConf + clientConf.set("spark.authenticate", "false") + val ackTimeout = 30 + clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}") + + val clientSecurityManager = new SecurityManager(clientConf) + val manager = new ConnectionManager(0, clientConf, clientSecurityManager) + + val serverConf = new SparkConf + serverConf.set("spark.authenticate", "false") + val serverSecurityManager = new SecurityManager(serverConf) + val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // sleep 60 sec > ack timeout for simulating server slow down or hang up + Thread.sleep(ackTimeout * 3 * 1000) + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + // Future should throw IOException in 30 sec. + // Otherwise TimeoutExcepton is thrown from Await.result. + // We expect TimeoutException is not thrown. + intercept[IOException] { + Await.result(future, (ackTimeout * 2) second) + } + + manager.stop() + managerServer.stop() + } + } diff --git a/docs/configuration.md b/docs/configuration.md index c408c468dcd94..981170d8b49b7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -884,6 +884,15 @@ Apart from these, the following properties are also available, and may be useful out and giving up. + + spark.core.connection.ack.wait.timeout + 60 + + Number of seconds for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + spark.ui.filters None From 0b354be2f9ec35547a60591acf4f4773a4869690 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 16 Aug 2014 15:13:34 -0700 Subject: [PATCH 158/231] [SPARK-3048][MLLIB] add LabeledPoint.parse and remove loadStreamingLabeledPoints Move `parse()` from `LabeledPointParser` to `LabeledPoint` and make it public. This breaks binary compatibility only when a user uses synthesized methods like `tupled` and `curried`, which is rare. `LabeledPoint.parse` is more consistent with `Vectors.parse`, which is why `LabeledPointParser` is not preferred. freeman-lab tdas Author: Xiangrui Meng Closes #1952 from mengxr/labelparser and squashes the following commits: c818fb2 [Xiangrui Meng] merge master ce20e6f [Xiangrui Meng] update mima excludes b386b8d [Xiangrui Meng] fix tests 2436b3d [Xiangrui Meng] add parse() to LabeledPoint (cherry picked from commit 7e70708a99949549adde00cb6246a9582bbc4929) Signed-off-by: Xiangrui Meng --- .../mllib/StreamingLinearRegression.scala | 7 +++---- .../spark/mllib/regression/LabeledPoint.scala | 2 +- .../StreamingLinearRegressionWithSGD.scala | 2 +- .../org/apache/spark/mllib/util/MLUtils.scala | 17 ++--------------- .../mllib/regression/LabeledPointSuite.scala | 4 ++-- .../StreamingLinearRegressionSuite.scala | 6 +++--- project/MimaExcludes.scala | 5 +++++ 7 files changed, 17 insertions(+), 26 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1fd37edfa7427..0e992fa9967bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -18,8 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -56,8 +55,8 @@ object StreamingLinearRegression { val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) - val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0)) - val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1)) + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 62a03af4a9964..17c753c56681f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -36,7 +36,7 @@ case class LabeledPoint(label: Double, features: Vector) { /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. */ -private[mllib] object LabeledPointParser { +object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 8851097050318..1d11fde24712c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vector /** * Train or predict a linear regression model on streaming data. Training uses diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index f4cce86a65ba7..ca35100aa99c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD import org.apache.spark.util.random.BernoulliSampler -import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext @@ -185,7 +185,7 @@ object MLUtils { * @return labeled points stored as an RDD[LabeledPoint] */ def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = - sc.textFile(path, minPartitions).map(LabeledPointParser.parse) + sc.textFile(path, minPartitions).map(LabeledPoint.parse) /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of @@ -194,19 +194,6 @@ object MLUtils { def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) - /** - * Loads streaming labeled points from a stream of text files - * where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`. - * See `StreamingContext.textFileStream` for more details on how to - * generate a stream from files - * - * @param ssc Streaming context - * @param dir Directory path in any Hadoop-supported file system URI - * @return Labeled points stored as a DStream[LabeledPoint] - */ - def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] = - ssc.textFileStream(dir).map(LabeledPointParser.parse) - /** * Load labeled data from a file. The data format used here is * , ... diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index d9308aaba6ee1..110c44a7193fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -28,12 +28,12 @@ class LabeledPointSuite extends FunSuite { LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) points.foreach { p => - assert(p === LabeledPointParser.parse(p.toString)) + assert(p === LabeledPoint.parse(p.toString)) } } test("parse labeled points with v0.9 format") { - val point = LabeledPointParser.parse("1.0,1.0 0.0 -2.0") + val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index ed21f84472c9a..45e25eecf508e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils} +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils @@ -55,7 +55,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { val numBatches = 10 val batchDuration = Milliseconds(1000) val ssc = new StreamingContext(sc, batchDuration) - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.1) @@ -97,7 +97,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { val batchDuration = Milliseconds(2000) val ssc = new StreamingContext(sc, batchDuration) val numBatches = 5 - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0)) .setStepSize(0.1) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bbe68b29d2d8e..300589394b96f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -129,6 +129,11 @@ object MimaExcludes { Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") ) ++ + Seq( // synthetic methods generated in LabeledPoint + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") + ) ++ Seq ( // Scala 2.11 compatibility fix ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") ) From a12d3ae3223535e6e4c774e4a289b8b2f2e5228b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 16 Aug 2014 15:14:43 -0700 Subject: [PATCH 159/231] [SPARK-3081][MLLIB] rename RandomRDDGenerators to RandomRDDs `RandomRDDGenerators` means factory for `RandomRDDGenerator`. However, its methods return RDDs but not RDDGenerators. So a more proper (and shorter) name would be `RandomRDDs`. dorx brkyvz Author: Xiangrui Meng Closes #1979 from mengxr/randomrdds and squashes the following commits: b161a2d [Xiangrui Meng] rename RandomRDDGenerators to RandomRDDs (cherry picked from commit ac6411c6e75906997c78de23dfdbc8d225b87cfd) Signed-off-by: Xiangrui Meng --- .../mllib/api/python/PythonMLLibAPI.scala | 2 +- ...omRDDGenerators.scala => RandomRDDs.scala} | 6 ++--- ...atorsSuite.scala => RandomRDDsSuite.scala} | 16 ++++++------ python/pyspark/mllib/random.py | 25 +++++++++---------- 4 files changed, 24 insertions(+), 25 deletions(-) rename mllib/src/main/scala/org/apache/spark/mllib/random/{RandomRDDGenerators.scala => RandomRDDs.scala} (99%) rename mllib/src/test/scala/org/apache/spark/mllib/random/{RandomRDDGeneratorsSuite.scala => RandomRDDsSuite.scala} (88%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 18dc087856785..4343124f102a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} +import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala rename to mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index b0a0593223910..36270369526cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.random +import scala.reflect.ClassTag + import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector @@ -24,14 +26,12 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import scala.reflect.ClassTag - /** * :: Experimental :: * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. */ @Experimental -object RandomRDDGenerators { +object RandomRDDs { /** * :: Experimental :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 96e0bc63b0fa4..c50b78bcbcc61 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Serializable { +class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, @@ -113,18 +113,18 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed) + val uniform = RandomRDDs.uniformRDD(sc, size, numPartitions, seed) testGeneratedRDD(uniform, size, numPartitions, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalRDD(sc, size, numPartitions, seed) + val normal = RandomRDDs.normalRDD(sc, size, numPartitions, seed) testGeneratedRDD(normal, size, numPartitions, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonRDD(sc, poissonMean, size, numPartitions, seed) + val poisson = RandomRDDs.poissonRDD(sc, poissonMean, size, numPartitions, seed) testGeneratedRDD(poisson, size, numPartitions, poissonMean, math.sqrt(poissonMean), 0.1) } // mock distribution to check that partitions have unique seeds - val random = RandomRDDGenerators.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) + val random = RandomRDDs.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) assert(random.collect.size === random.collect.distinct.size) } @@ -135,13 +135,13 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformVectorRDD(sc, rows, cols, parts, seed) + val uniform = RandomRDDs.uniformVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(uniform, rows, cols, parts, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalVectorRDD(sc, rows, cols, parts, seed) + val normal = RandomRDDs.normalVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(normal, rows, cols, parts, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) + val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) } } diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index eb496688b6eef..3f3b19053d32e 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -25,8 +25,7 @@ from pyspark.serializers import NoOpSerializer -class RandomRDDGenerators: - +class RandomRDDs: """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. @@ -40,17 +39,17 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} - >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 >>> max(x) <= 1.0 and min(x) >= 0.0 True - >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() 4 - >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() >>> parts == sc.defaultParallelism True """ @@ -66,10 +65,10 @@ def normalRDD(sc, size, numPartitions=None, seed=None): To transform the distribution in the generated RDD from standard normal to some other normal N(mean, sigma), use - C{RandomRDDGenerators.normal(sc, n, p, seed)\ + C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} - >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -89,7 +88,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): distribution with the input mean. >>> mean = 100.0 - >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -110,12 +109,12 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): from the uniform distribution on [0.0 1.0]. >>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape (10, 10) >>> mat.max() <= 1.0 and mat.min() >= 0.0 True - >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 """ jrdd = sc._jvm.PythonMLLibAPI() \ @@ -130,7 +129,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): from the standard normal distribution. >>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - 0.0) < 0.1 @@ -151,7 +150,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> import numpy as np >>> mean = 100.0 - >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) From 721f2fdc95032132af3d4a00dbc8399d356f8faf Mon Sep 17 00:00:00 2001 From: iAmGhost Date: Sat, 16 Aug 2014 16:48:38 -0700 Subject: [PATCH 160/231] [SPARK-3035] Wrong example with SparkContext.addFile https://issues.apache.org/jira/browse/SPARK-3035 fix for wrong document. Author: iAmGhost Closes #1942 from iAmGhost/master and squashes the following commits: 487528a [iAmGhost] [SPARK-3035] Wrong example with SparkContext.addFile fix for wrong document. (cherry picked from commit 379e7585c356f20bf8b4878ecba9401e2195da12) Signed-off-by: Josh Rosen --- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4001ecab5ea00..6c049238819a7 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -613,7 +613,7 @@ def addFile(self, path): >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: ... fileVal = int(testFile.readline()) - ... return [x * 100 for x in iterator] + ... return [x * fileVal for x in iterator] >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() [100, 200, 300, 400] """ From 5dd571c29ef97cadd23a54fcf4d5de869e3c56bc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 16 Aug 2014 16:59:34 -0700 Subject: [PATCH 161/231] [SPARK-1065] [PySpark] improve supporting for large broadcast Passing large object by py4j is very slow (cost much memory), so pass broadcast objects via files (similar to parallelize()). Add an option to keep object in driver (it's False by default) to save memory in driver. Author: Davies Liu Closes #1912 from davies/broadcast and squashes the following commits: e06df4a [Davies Liu] load broadcast from disk in driver automatically db3f232 [Davies Liu] fix serialization of accumulator 631a827 [Davies Liu] Merge branch 'master' into broadcast c7baa8c [Davies Liu] compress serrialized broadcast and command 9a7161f [Davies Liu] fix doc tests e93cf4b [Davies Liu] address comments: add test 6226189 [Davies Liu] improve large broadcast (cherry picked from commit 2fc8aca086a2679b854038b7e2c488f19039ecbd) Signed-off-by: Josh Rosen --- .../apache/spark/api/python/PythonRDD.scala | 8 ++++ python/pyspark/broadcast.py | 37 ++++++++++++++----- python/pyspark/context.py | 20 ++++++---- python/pyspark/rdd.py | 5 ++- python/pyspark/serializers.py | 17 +++++++++ python/pyspark/tests.py | 7 ++++ python/pyspark/worker.py | 8 ++-- 7 files changed, 81 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 9f5c5bd30f0c9..10210a2927dcc 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -315,6 +315,14 @@ private[spark] object PythonRDD extends Logging { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } + def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + sc.broadcast(obj) + } + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { // The right way to implement this would be to use TypeTags to get the full // type of T. Since I don't want to introduce breaking changes throughout the diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index f3e64989ed564..675a2fcd2ff4e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,18 +21,16 @@ >>> b = sc.broadcast([1, 2, 3, 4, 5]) >>> b.value [1, 2, 3, 4, 5] - ->>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.bid] = b ->>> from cPickle import dumps, loads ->>> loads(dumps(b)).value -[1, 2, 3, 4, 5] - >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +>>> b.unpersist() >>> large_broadcast = sc.broadcast(list(range(10000))) """ +import os + +from pyspark.serializers import CompressedSerializer, PickleSerializer + # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} @@ -52,17 +50,38 @@ class Broadcast(object): Access its value through C{.value}. """ - def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, + pickle_registry=None, path=None): """ Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ - self.value = value self.bid = bid + if path is None: + self.value = value self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry + self.path = path + + def unpersist(self, blocking=False): + self._jbroadcast.unpersist(blocking) + os.unlink(self.path) def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) + + def __getattr__(self, item): + if item == 'value' and self.path is not None: + ser = CompressedSerializer(PickleSerializer()) + value = ser.load_stream(open(self.path)).next() + self.value = value + return value + + raise AttributeError(item) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6c049238819a7..a90870ed3a353 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer + PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -566,13 +566,19 @@ def broadcast(self, value): """ Broadcast a read-only variable to the cluster, returning a L{Broadcast} - object for reading it in distributed functions. The variable will be - sent to each cluster only once. + object for reading it in distributed functions. The variable will + be sent to each cluster only once. + + :keep: Keep the `value` in driver or not. """ - pickleSer = PickleSerializer() - pickled = pickleSer.dumps(value) - jbroadcast = self._jsc.broadcast(bytearray(pickled)) - return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + ser = CompressedSerializer(PickleSerializer()) + # pass large object by py4j is very slow and need much memory + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) + ser.dump_stream([value], tempFile) + tempFile.close() + jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) + return Broadcast(jbroadcast.id(), None, jbroadcast, + self._pickled_broadcast_vars, tempFile.name) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3934bdda0a466..240381e5bae12 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -36,7 +36,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long + PickleSerializer, pack_long, CompressedSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -1810,7 +1810,8 @@ def _jrdd(self): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) - pickled_command = CloudPickleSerializer().dumps(command) + ser = CompressedSerializer(CloudPickleSerializer()) + pickled_command = ser.dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index df90cafb245bf..74870c0edcf99 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -67,6 +67,7 @@ import sys import types import collections +import zlib from pyspark import cloudpickle @@ -403,6 +404,22 @@ def loads(self, obj): raise ValueError("invalid sevialization type: %s" % _type) +class CompressedSerializer(FramedSerializer): + """ + compress the serialized data + """ + + def __init__(self, serializer): + FramedSerializer.__init__(self) + self.serializer = serializer + + def dumps(self, obj): + return zlib.compress(self.serializer.dumps(obj), 1) + + def loads(self, obj): + return self.serializer.loads(zlib.decompress(obj)) + + class UTF8Deserializer(Serializer): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 22b51110ed671..f1fece998cd54 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,13 @@ def test_namedtuple_in_rdd(self): theDoes = self.sc.parallelize([jon, jane]) self.assertEquals([jon, jane], theDoes.collect()) + def test_large_broadcast(self): + N = 100000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 270MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEquals(N, m) + class TestIO(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2770f63059853..77a9c4a0e0677 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,7 +30,8 @@ from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ + CompressedSerializer pickleSer = PickleSerializer() @@ -65,12 +66,13 @@ def main(infile, outfile): # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) + ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = pickleSer._read_with_length(infile) + value = ser._read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, value) - command = pickleSer._read_with_length(infile) + command = ser._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) From f02e327f0bc975e7f33092e449bc0edd95f95580 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 16 Aug 2014 20:05:55 -0700 Subject: [PATCH 162/231] In the stop method of ConnectionManager to cancel the ackTimeoutMonitor cc JoshRosen sarutak Author: GuoQiang Li Closes #1989 from witgo/cancel_ackTimeoutMonitor and squashes the following commits: 4a700fa [GuoQiang Li] In the stop method of ConnectionManager to cancel the ackTimeoutMonitor (cherry picked from commit bc95fe08dff62a0abea314ab4ab9275c8f119598) Signed-off-by: Josh Rosen --- .../main/scala/org/apache/spark/network/ConnectionManager.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 37d69a9ec4ce4..e77d762bdf221 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -886,6 +886,7 @@ private[spark] class ConnectionManager( } def stop() { + ackTimeoutMonitor.cancel() selectorThread.interrupt() selectorThread.join() selector.close() From 413a329e186de2ec96f80f614c36678bee6f332f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 16 Aug 2014 21:16:27 -0700 Subject: [PATCH 163/231] [SPARK-3077][MLLIB] fix some chisq-test - promote nullHypothesis field in ChiSqTestResult to TestResult. Every test should have a null hypothesis - correct null hypothesis statement for independence test - p-value: 0.01 -> 0.1 Author: Xiangrui Meng Closes #1982 from mengxr/fix-chisq and squashes the following commits: 5f0de02 [Xiangrui Meng] make ChiSqTestResult constructor package private bc74ea1 [Xiangrui Meng] update chisq-test (cherry picked from commit fbad72288d8b6e641b00417a544cae6e8bfef2d7) Signed-off-by: Xiangrui Meng --- .../spark/mllib/stat/test/ChiSqTest.scala | 2 +- .../spark/mllib/stat/test/TestResult.scala | 28 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 8f6752737402e..215de95db5113 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -56,7 +56,7 @@ private[stat] object ChiSqTest extends Logging { object NullHypothesis extends Enumeration { type NullHypothesis = Value val goodnessOfFit = Value("observed follows the same distribution as expected.") - val independence = Value("observations in each column are statistically independent.") + val independence = Value("the occurrence of the outcomes is statistically independent.") } // Method identification based on input methodName string diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index 2f278621335e1..4784f9e947908 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -44,6 +44,11 @@ trait TestResult[DF] { */ def statistic: Double + /** + * Null hypothesis of the test. + */ + def nullHypothesis: String + /** * String explaining the hypothesis test result. * Specific classes implementing this trait should override this method to output test-specific @@ -53,13 +58,13 @@ trait TestResult[DF] { // String explaining what the p-value indicates. val pValueExplain = if (pValue <= 0.01) { - "Very strong presumption against null hypothesis." + s"Very strong presumption against null hypothesis: $nullHypothesis." } else if (0.01 < pValue && pValue <= 0.05) { - "Strong presumption against null hypothesis." - } else if (0.05 < pValue && pValue <= 0.01) { - "Low presumption against null hypothesis." + s"Strong presumption against null hypothesis: $nullHypothesis." + } else if (0.05 < pValue && pValue <= 0.1) { + s"Low presumption against null hypothesis: $nullHypothesis." } else { - "No presumption against null hypothesis." + s"No presumption against null hypothesis: $nullHypothesis." } s"degrees of freedom = ${degreesOfFreedom.toString} \n" + @@ -70,19 +75,18 @@ trait TestResult[DF] { /** * :: Experimental :: - * Object containing the test results for the chi squared hypothesis test. + * Object containing the test results for the chi-squared hypothesis test. */ @Experimental -class ChiSqTestResult(override val pValue: Double, +class ChiSqTestResult private[stat] (override val pValue: Double, override val degreesOfFreedom: Int, override val statistic: Double, val method: String, - val nullHypothesis: String) extends TestResult[Int] { + override val nullHypothesis: String) extends TestResult[Int] { override def toString: String = { - "Chi squared test summary: \n" + - s"method: $method \n" + - s"null hypothesis: $nullHypothesis \n" + - super.toString + "Chi squared test summary:\n" + + s"method: $method\n" + + super.toString } } From 91af120b4391656cb8f7b2300202dc622c032c33 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 16 Aug 2014 23:53:14 -0700 Subject: [PATCH 164/231] [SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up DecisionTree needs to match each example to a node at each iteration. It currently does this with a set of filters very inefficiently: For each example, it examines each node at the current level and traces up to the root to see if that example should be handled by that node. Fix: Filter top-down using the partly built tree itself. Major changes: * Eliminated Filter class, findBinsForLevel() method. * Set up node parent links in main loop over levels in train(). * Added predictNodeIndex() for filtering top-down. * Added DTMetadata class Other changes: * Pre-compute set of unorderedFeatures. Notes for following expected PR based on [https://issues.apache.org/jira/browse/SPARK-3043]: * The unorderedFeatures set will next be stored in a metadata structure to simplify function calls (to store other items such as the data in strategy). I've done initial tests indicating that this speeds things up, but am only now running large-scale ones. CC: mengxr manishamde chouqin Any comments are welcome---thanks! Author: Joseph K. Bradley Closes #1975 from jkbradley/dt-opt2 and squashes the following commits: a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata. Small doc updates. 3726d20 [Joseph K. Bradley] Small code improvements based on code review. ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow. db0d773 [Joseph K. Bradley] scala style fix 6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code 931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level. Needed to update treePointToNodeIndex with groupShift. f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint 6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory. 2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala. 356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs. d036089 [Joseph K. Bradley] Print timing info to logDebug. e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private 8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serialiable a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt 0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree 3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing 511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing a95bc22 [Joseph K. Bradley] timing for DecisionTree internals (cherry picked from commit 73ab7f141c205df277c6ac19252e590d6806c41f) Signed-off-by: Xiangrui Meng --- .../spark/mllib/tree/DecisionTree.scala | 878 ++++++++---------- .../tree/impl/DecisionTreeMetadata.scala | 101 ++ .../spark/mllib/tree/impl/TreePoint.scala | 30 +- .../apache/spark/mllib/tree/model/Bin.scala | 18 +- .../mllib/tree/model/DecisionTreeModel.scala | 2 +- .../spark/mllib/tree/model/Filter.scala | 28 - .../apache/spark/mllib/tree/model/Node.scala | 16 +- .../apache/spark/mllib/tree/model/Split.scala | 5 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 167 ++-- 9 files changed, 615 insertions(+), 630 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2a3107a13e916..6b9a8f72c244e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -62,43 +62,38 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.start("init") val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. timer.start("findSplitsBins") - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) val numBins = bins(0).length timer.stop("findSplitsBins") logDebug("numBins = " + numBins) + // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) .persist(StorageLevel.MEMORY_AND_DISK) + val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 - // Initialize an array to hold filters applied to points for each node. - val filters = new Array[List[Filter]](maxNumNodes) - // The filter at the top node is an empty list. - filters(0) = List() + val maxNumNodes = (2 << maxDepth) - 1 // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // num features - val numFeatures = treeInput.take(1)(0).binnedFeatures.size // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, - strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, - strategy.algo) + val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -114,9 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /* * The main idea here is to perform level-wise training of the decision tree nodes thus * reducing the passes over the data from l to log2(l) where l is the total number of nodes. - * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., - * the sample is only used for the split calculation at the node if the sampled would have - * still survived the filters of the parent nodes. + * Each data sample is handled by a particular node at that level (or it reaches a leaf + * beforehand and is not used in later levels. */ var level = 0 @@ -130,22 +124,37 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. timer.start("findBestSplits") val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) + metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") + val levelNodeIndexOffset = (1 << level) - 1 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + val nodeIndex = levelNodeIndexOffset + index + val isLeftChild = level != 0 && nodeIndex % 2 == 1 + val parentNodeIndex = if (isLeftChild) { // -1 for root node + (nodeIndex - 1) / 2 + } else { + (nodeIndex - 2) / 2 + } + // Extract info for this node (index) at the current level. timer.start("extractNodeInfo") - // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) timer.stop("extractNodeInfo") - timer.start("extractInfoForLowerLevels") + if (level != 0) { + // Set parent. + if (isLeftChild) { + nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) + } else { + nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) + } + } // Extract info for nodes at the next lower level. - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, - filters) + timer.start("extractInfoForLowerLevels") + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } - require(math.pow(2, level) == splitsStatsForLevel.length) + require((1 << level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -183,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = math.pow(2, level).toInt - 1 + index + val nodeIndex = (1 << level) - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -198,31 +207,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double], - filters: Array[List[Filter]]): Unit = { - // 0 corresponds to the left child node and 1 corresponds to the right child node. - var i = 0 - while (i <= 1) { - // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth) { - val impurity = if (i == 0) { - nodeSplitStats._2.leftImpurity - } else { - nodeSplitStats._2.rightImpurity - } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) - // noting the parent impurities - parentImpurities(nodeIndex) = impurity - // noting the parents filters for the child nodes - val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) - filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - for (filter <- filters(nodeIndex)) { - logDebug("Filter = " + filter) - } - } - i += 1 + parentImpurities: Array[Double]): Unit = { + + if (level >= maxDepth) { + return } + + val leftNodeIndex = (2 << level) - 1 + 2 * index + val leftImpurity = nodeSplitStats._2.leftImpurity + logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) + parentImpurities(leftNodeIndex) = leftImpurity + + val rightNodeIndex = leftNodeIndex + 1 + val rightImpurity = nodeSplitStats._2.rightImpurity + logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) + parentImpurities(rightNodeIndex) = rightImpurity } } @@ -434,10 +433,8 @@ object DecisionTree extends Serializable with Logging { * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. @@ -446,9 +443,9 @@ object DecisionTree extends Serializable with Logging { protected[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, @@ -459,34 +456,32 @@ object DecisionTree extends Serializable with Logging { // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt + val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, timer, numGroups, groupIndex) + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, + nodes, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer) + findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) } } - /** + /** * Returns an array of optimal splits for a group of nodes at a given level * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features - * @param bins possible bins for all features + * @param bins possible bins for all features, indexed as (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. @@ -494,9 +489,9 @@ object DecisionTree extends Serializable with Logging { private def findBestSplitsPerGroup( input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], timer: TimeTracker, @@ -515,7 +510,7 @@ object DecisionTree extends Serializable with Logging { * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * bin. Each bin is an interval between a low and high split. Since each split, and thus bin, * is ordered (read ordering for categorical variables in the findSplitsBins method), * we exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. @@ -531,160 +526,124 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = math.pow(2, level).toInt / numGroups + val numNodes = (1 << level) / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. - val numFeatures = input.first().binnedFeatures.size + val numFeatures = metadata.numFeatures logDebug("numFeatures = " + numFeatures) // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) - val numClasses = strategy.numClassesForClassification + val numClasses = metadata.numClasses logDebug("numClasses = " + numClasses) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) - val isMulticlassClassificationWithCategoricalFeatures - = strategy.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + - isMulticlassClassificationWithCategoricalFeatures) + val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** Find the filters used before reaching the current code. */ - def findParentFilters(nodeIndex: Int): List[Filter] = { - if (level == 0) { - List[Filter]() - } else { - val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift - filters(nodeFilterIndex) - } - } - /** - * Find whether the sample is valid input for the current node, i.e., whether it passes through - * all the filters for the current node. + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a node + * at the current level being trained; that node's index is returned. + * + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. */ - def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { - // leaf - if ((level > 0) && (parentFilters.length == 0)) { - return false - } - - // Apply each filter and check sample validity. Return false when invalid condition found. - parentFilters.foreach { filter => - val featureIndex = filter.split.feature - val comparison = filter.comparison - val isFeatureContinuous = filter.split.featureType == Continuous - if (isFeatureContinuous) { - val binId = treePoint.binnedFeatures(featureIndex) - val bin = bins(featureIndex)(binId) - val featureValue = bin.highSplit.threshold - val threshold = filter.split.threshold - comparison match { - case -1 => if (featureValue > threshold) return false - case 1 => if (featureValue <= threshold) return false + def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. + featureValueUpperBound <= node.split.get.threshold } - } else { - val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits = - numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1 - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - val featureValue = if (isUnorderedFeature) { - treePoint.binnedFeatures(featureIndex) + case Categorical => { + val featureValue = if (metadata.isUnordered(featureIndex)) { + binnedFeatures(featureIndex) + } else { + val binIndex = binnedFeatures(featureIndex) + bins(featureIndex)(binIndex).category + } + node.split.get.categories.contains(featureValue) + } + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") + } + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + node.id * 2 + 1 // left } else { - val binId = treePoint.binnedFeatures(featureIndex) - bins(featureIndex)(binId).category + node.id * 2 + 2 // right } - val containsFeature = filter.split.categories.contains(featureValue) - comparison match { - case -1 => if (!containsFeature) return false - case 1 => if (containsFeature) return false + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, binnedFeatures) + } else { + predictNodeIndex(node.rightNode.get, binnedFeatures) } } } + } - // Return true when the sample is valid for all filters. - true + def nodeIndexToLevel(idx: Int): Int = { + if (idx == 0) { + 0 + } else { + math.floor(math.log(idx) / math.log(2)).toInt + } } + // Used for treePointToNodeIndex + val levelOffset = (1 << level) - 1 + /** - * Finds bins for all nodes (and all features) at a given level. - * For l nodes, k features the storage is as follows: - * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1 for regressions and binary - * classification and the categorical feature value in multiclass classification. - * Invalid sample is denoted by noting bin for feature 1 as -1. - * - * For unordered features, the "bin index" returned is actually the feature value (category). - * - * @return Array of size 1 + numFeatures * numNodes, where - * arr(0) = label for labeledPoint, and - * arr(1 + numFeatures * nodeIndex + featureIndex) = - * bin index for this labeledPoint - * (or InvalidBinIndex if labeledPoint is not handled by this node) + * Find the node index for the given example. + * Nodes are indexed from 0 at the start of this (level, group). + * If the example does not reach this level, returns a value < 0. */ - def findBinsForLevel(treePoint: TreePoint): Array[Double] = { - // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (numFeatures * numNodes)) - // First element of the array is the label of the instance. - arr(0) = treePoint.label - // Iterate over nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - val parentFilters = findParentFilters(nodeIndex) - // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, treePoint) - val shift = 1 + numFeatures * nodeIndex - if (!sampleValid) { - // Mark one bin as -1 is sufficient. - arr(shift) = InvalidBinIndex - } else { - var featureIndex = 0 - while (featureIndex < numFeatures) { - arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 + def treePointToNodeIndex(treePoint: TreePoint): Int = { + if (level == 0) { + 0 + } else { + val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures) + // Get index for this (level, group). + globalNodeIndex - levelOffset - groupShift } - arr } - // Find feature bins for all nodes at a level. - timer.start("aggregation") - val binMappedRDD = input.map(x => findBinsForLevel(x)) - /** * Increment aggregate in location for (node, feature, bin, label). * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ def updateBinForOrderedFeature( - arr: Array[Double], + treePoint: TreePoint, agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int): Unit = { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. val aggIndex = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - numClasses * arr(arrIndex).toInt + - label.toInt + numClasses * treePoint.binnedFeatures(featureIndex) + + treePoint.label.toInt agg(aggIndex) += 1 } @@ -693,8 +652,8 @@ object DecisionTree extends Serializable with Logging { * where [bins] ranges over all bins. * Updates left or right side of aggregate depending on split. * - * @param arr arr(0) = label. - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param treePoint Data point being aggregated. * @param agg Indexed by (left/right, node, feature, bin, label) * where label is the least significant bit. * The left/right specifier is a 0/1 index indicating left/right child info. @@ -703,21 +662,18 @@ object DecisionTree extends Serializable with Logging { def updateBinForUnorderedFeature( nodeIndex: Int, featureIndex: Int, - arr: Array[Double], - label: Double, + treePoint: TreePoint, agg: Array[Double], rightChildShift: Int): Unit = { - // Find the bin index for this feature. - val arrIndex = 1 + numFeatures * nodeIndex + featureIndex - val featureValue = arr(arrIndex).toInt + val featureValue = treePoint.binnedFeatures(featureIndex) // Update the left or right count for one bin. val aggShift = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - label.toInt + treePoint.label.toInt // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureCategories = metadata.featureArity(featureIndex) + val numCategoricalBins = (1 << featureCategories - 1) - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val aggIndex = aggShift + binIndex * numClasses @@ -733,30 +689,21 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 + def binaryOrNotCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) + featureIndex += 1 } } @@ -765,49 +712,28 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). - * For ordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. - * For unordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). * @param agg Array storing aggregate calculation. * For ordered features, this is of size: * numClasses * numBins * numFeatures * numNodes. * For unordered features, this is of size: * 2 * numClasses * numBins * numFeatures * numNodes. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift) - } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } - } - featureIndex += 1 - } + def multiclassWithCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) + } else { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) } - nodeIndex += 1 + featureIndex += 1 } } @@ -818,36 +744,25 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, updated by this function. * Size: 3 * numBins * numFeatures * numNodes - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 - agg(aggIndex) = agg(aggIndex) + 1 - agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label * label - featureIndex += 1 - } - } - nodeIndex += 1 + def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Update count, sum, and sum^2 for one bin. + val binIndex = treePoint.binnedFeatures(featureIndex) + val aggIndex = + 3 * numBins * numFeatures * nodeIndex + + 3 * numBins * featureIndex + + 3 * binIndex + agg(aggIndex) += 1 + agg(aggIndex + 1) += label + agg(aggIndex + 2) += label * label + featureIndex += 1 } } @@ -866,26 +781,30 @@ object DecisionTree extends Serializable with Logging { * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. * Size for regression: * 3 * numBins * numFeatures * numNodes. - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @return agg */ - def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => - if(isMulticlassClassificationWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(arr, agg) + def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + val nodeIndex = treePointToNodeIndex(treePoint) + // If the example does not reach this level, then nodeIndex < 0. + // If the example reaches this level but is handled in a different group, + // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). + if (nodeIndex >= 0 && nodeIndex < numNodes) { + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) } else { - binaryOrNotCategoricalBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) } - case Regression => regressionBinSeqOp(arr, agg) + } else { + regressionBinSeqOp(agg, treePoint, nodeIndex) + } } agg } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, - isMulticlassClassificationWithCategoricalFeatures, strategy.algo) + val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -905,144 +824,134 @@ object DecisionTree extends Serializable with Logging { } // Calculate bin aggregates. + timer.start("aggregation") val binAggregates = { - binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits based upon left/right split aggregates. - * @param leftNodeAgg left node aggregates - * @param featureIndex feature index - * @param splitIndex split index - * @param rightNodeAgg right node aggregate + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftNodeAgg left node aggregates for this (feature, split) + * @param rightNodeAgg right node aggregate for this (feature, split) * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Array[Double]]], + leftNodeAgg: Array[Double], + rightNodeAgg: Array[Double], topImpurity: Double): InformationGainStats = { - strategy.algo match { - case Classification => - val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) - val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) - val leftTotalCount = leftCounts.sum - val rightTotalCount = rightCounts.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) - classIndex += 1 - } - strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) - } - } + if (metadata.isClassification) { + val leftTotalCount = leftNodeAgg.sum + val rightTotalCount = rightNodeAgg.sum - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) + classIndex += 1 + } + metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) } + } - // Sum of count for each label - val leftRightCounts: Array[Double] = - leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => - leftCount + rightCount - } + val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if (currentValue > maxValue) { - (currentIndex, currentValue, currentIndex + 1) - } else { - (maxIndex, maxValue, currentIndex + 1) - } - } - if (result._1 < 0) { - throw new RuntimeException("DecisionTree internal error:" + - " calculateGainForSplit failed in indexOfLargestArrayElement") - } - result._1 + // Sum of count for each label + val leftrightNodeAgg: Array[Double] = + leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => + leftCount + rightCount } - val predict = indexOfLargestArrayElement(leftRightCounts) - val prob = leftRightCounts(predict) / totalCount - - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(leftCounts, leftTotalCount) + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } } - val rightImpurity = if (rightTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(rightCounts, rightTotalCount) + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") } + result._1 + } - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount + val predict = indexOfLargestArrayElement(leftrightNodeAgg) + val prob = leftrightNodeAgg(predict) / totalCount - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(leftNodeAgg, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(rightNodeAgg, rightTotalCount) + } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount - case Regression => - val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) - val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) - val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) - val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) - val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - strategy.impurity.calculate(count, sum, sumSquares) - } - } + } else { + // Regression - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum / leftCount) + val leftCount = leftNodeAgg(0) + val leftSum = leftNodeAgg(1) + val leftSumSquares = leftNodeAgg(2) + + val rightCount = rightNodeAgg(0) + val rightSum = rightNodeAgg(1) + val rightSumSquares = rightNodeAgg(2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + metadata.impurity.calculate(count, sum, sumSquares) } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, + Double.MinValue, leftSum / leftCount) + } - val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } @@ -1065,6 +974,19 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + /** + * The input binData is indexed as (feature, bin, class). + * This computes cumulative sums over splits. + * Each (feature, class) pair is handled separately. + * Note: numSplits = numBins - 1. + * @param leftNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 0, ..., numSplits - 2) is set to be + * the cumulative sum (from left) over binData for bins 0, ..., i. + * @param rightNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 1, ..., numSplits - 1) is set to be + * the cumulative sum (from right) over binData for bins + * numBins - 1, ..., numBins - 1 - i. + */ def findAggForOrderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1169,45 +1091,32 @@ object DecisionTree extends Serializable with Logging { } } - strategy.algo match { - case Classification => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - } - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - - (leftNodeAgg, rightNodeAgg) - case Regression => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 + if (metadata.isClassification) { + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } - (leftNodeAgg, rightNodeAgg) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } else { + // Regression + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) } } @@ -1225,8 +1134,9 @@ object DecisionTree extends Serializable with Logging { val numSplitsForFeature = getNumSplitsForFeature(featureIndex) var splitIndex = 0 while (splitIndex < numSplitsForFeature) { - gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity) + gains(featureIndex)(splitIndex) = + calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), + rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) splitIndex += 1 } featureIndex += 1 @@ -1238,18 +1148,14 @@ object DecisionTree extends Serializable with Logging { * Get the number of splits for a feature. */ def getNumSplitsForFeature(featureIndex: Int): Int = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { + if (metadata.isContinuous(featureIndex)) { numBins - 1 } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits = - numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 + val featureCategories = metadata.featureArity(featureIndex) + if (metadata.isUnordered(featureIndex)) { + (1 << featureCategories - 1) - 1 } else { - // Ordered features featureCategories } } @@ -1308,29 +1214,29 @@ object DecisionTree extends Serializable with Logging { * Get bin data for one node. */ def getBinDataForNode(node: Int): Array[Double] = { - strategy.algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData } - case Regression => - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } + } else { + // Regression + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode } } @@ -1340,7 +1246,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift + val nodeImpurityIndex = (1 << level) - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) @@ -1358,20 +1264,15 @@ object DecisionTree extends Serializable with Logging { * * @param numBins Number of bins = 1 + number of possible splits. */ - private def getElementsPerNode( - numFeatures: Int, - numBins: Int, - numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, - algo: Algo): Int = { - algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - 2 * numClasses * numBins * numFeatures - } else { - numClasses * numBins * numFeatures - } - case Regression => 3 * numBins * numFeatures + private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = { + if (metadata.isClassification) { + if (metadata.isMulticlassWithCategoricalFeatures) { + 2 * metadata.numClasses * numBins * metadata.numFeatures + } else { + metadata.numClasses * numBins * metadata.numFeatures + } + } else { + 3 * numBins * metadata.numFeatures } } @@ -1390,16 +1291,15 @@ object DecisionTree extends Serializable with Logging { * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are math.pow(2, maxFeatureValue - 1) - 1 splits. + * There are (1 << maxFeatureValue - 1) - 1 splits. * (b) "ordered features" * For regression and binary classification, * and for multiclass classification with a high-arity feature, * there is one bin per category. * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return A tuple of (splits,bins). + * @param metadata Learning and dataset metadata + * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numBins - 1). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] @@ -1407,19 +1307,18 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() // Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.size - val maxBins = strategy.maxBins + val maxBins = metadata.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) - + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) /* * Ensure numBins is always greater than the categories. For multiclass classification, @@ -1431,13 +1330,12 @@ object DecisionTree extends Serializable with Logging { * by the number of training examples. * TODO: Allow this case, where we simply will know nothing about some categories. */ - if (strategy.categoricalFeaturesInfo.size > 0) { - val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + if (metadata.featureArity.size > 0) { + val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") } - // Calculate the number of sample for approximate quantile calculation. val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 @@ -1451,7 +1349,7 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - strategy.quantileCalculationStrategy match { + metadata.quantileStrategy match { case Sort => val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) @@ -1462,7 +1360,7 @@ object DecisionTree extends Serializable with Logging { var featureIndex = 0 while (featureIndex < numFeatures) { // Check whether the feature is continuous. - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / numBins @@ -1475,18 +1373,14 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val featureCategories = metadata.featureArity(featureIndex) // Use different bin/split calculation strategy for categorical features in multiclass // classification that satisfy the space constraint. - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { + if (metadata.isUnordered(featureIndex)) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 - while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { + while (index < (1 << featureCategories - 1) - 1) { val categories: List[Double] = extractMultiClassCategories(index + 1, featureCategories) splits(featureIndex)(index) @@ -1516,7 +1410,7 @@ object DecisionTree extends Serializable with Logging { * centroidForCategories is a mapping: category (for the given feature) --> centroid */ val centroidForCategories = { - if (isMulticlassClassification) { + if (isMulticlass) { // For categorical variables in multiclass classification, // each bin is a category. The bins are sorted and they // are ordered by calculating the impurity of their corresponding labels. @@ -1524,7 +1418,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) + .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. The bins are sorted and they @@ -1576,7 +1470,7 @@ object DecisionTree extends Serializable with Logging { // Find all bins. featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) @@ -1590,7 +1484,7 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - (splits,bins) + (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") case ApproxHist => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000000..d9eda354dc986 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. + */ +private[tree] class DecisionTreeMetadata( + val numFeatures: Int, + val numExamples: Long, + val numClasses: Int, + val maxBins: Int, + val featureArity: Map[Int, Int], + val unorderedFeatures: Set[Int], + val impurity: Impurity, + val quantileStrategy: QuantileStrategy) extends Serializable { + + def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) + + def isClassification: Boolean = numClasses >= 2 + + def isMulticlass: Boolean = numClasses > 2 + + def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) + + def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) + + def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + +} + +private[tree] object DecisionTreeMetadata { + + def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { + + val numFeatures = input.take(1)(0).features.size + val numExamples = input.count() + val numClasses = strategy.algo match { + case Classification => strategy.numClassesForClassification + case Regression => 0 + } + + val maxBins = math.min(strategy.maxBins, numExamples).toInt + val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0) + + val unorderedFeatures = new mutable.HashSet[Int]() + if (numClasses > 2) { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + if (k - 1 < log2MaxBinsp1) { + // Note: The above check is equivalent to checking: + // numUnorderedBins = (1 << k - 1) - 1 < maxBins + unorderedFeatures.add(f) + } else { + // TODO: Allow this case, where we simply will know nothing about some categories? + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + } else { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, + strategy.impurity, strategy.quantileCalculationStrategy) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index ccac1031fd9d9..170e43e222083 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.impl import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.model.Bin import org.apache.spark.rdd.RDD @@ -48,50 +47,35 @@ private[tree] object TreePoint { * Convert an input dataset into its TreePoint representation, * binning feature values in preparation for DecisionTree training. * @param input Input dataset. - * @param strategy DecisionTree training info, used for dataset metadata. * @param bins Bins for features, of size (numFeatures, numBins). + * @param metadata Learning and dataset metadata * @return TreePoint dataset representation */ def convertToTreeRDD( input: RDD[LabeledPoint], - strategy: Strategy, - bins: Array[Array[Bin]]): RDD[TreePoint] = { + bins: Array[Array[Bin]], + metadata: DecisionTreeMetadata): RDD[TreePoint] = { input.map { x => - TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins, - strategy.categoricalFeaturesInfo) + TreePoint.labeledPointToTreePoint(x, bins, metadata) } } /** * Convert one LabeledPoint into its TreePoint representation. * @param bins Bins for features, of size (numFeatures, numBins). - * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, - isMulticlassClassification: Boolean, bins: Array[Array[Bin]], - categoricalFeaturesInfo: Map[Int, Int]): TreePoint = { + metadata: DecisionTreeMetadata): TreePoint = { val numFeatures = labeledPoint.features.size val numBins = bins(0).size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { - val featureInfo = categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false, - bins, categoricalFeaturesInfo) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isUnorderedFeature, bins, categoricalFeaturesInfo) - } + arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), + metadata.isUnordered(featureIndex), bins, metadata.featureArity) featureIndex += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index c89c1e371a40e..af35d88f713e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ /** - * Used for "binning" the features bins for faster best split calculation. For a continuous - * feature, a bin is determined by a low and a high "split". For a categorical feature, - * the a bin is determined using a single label value (category). + * Used for "binning" the features bins for faster best split calculation. + * + * For a continuous feature, the bin is determined by a low and a high split, + * where an example with featureValue falls into the bin s.t. + * lowSplit.threshold < featureValue <= highSplit.threshold. + * + * For ordered categorical features, there is a 1-1-1 correspondence between + * bins, splits, and feature values. The bin is determined by category/feature value. + * However, the bins are not necessarily ordered by feature value; + * they are ordered using impurity. + * For unordered categorical features, there is a 1-1 correspondence between bins, splits, + * where bins and splits correspond to subsets of feature values (in highSplit.categories). + * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for binary classification + * @param category categorical label value accepted in the bin for ordered features */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 3d3406b5d5f22..0594fd0749d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @return Double prediction from the trained model */ def predict(features: Vector): Double = { - topNode.predictIfLeaf(features) + topNode.predict(features) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala deleted file mode 100644 index 2deaf4ae8dcab..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -/** - * Filter specifying a split and type of comparison to be applied on features - * @param split split specifying the feature index, type and threshold - * @param comparison integer specifying <,=,> - */ -private[tree] case class Filter(split: Split, comparison: Int) { - // Comparison -1,0,1 signifies <.=,> - override def toString = " split = " + split + "comparison = " + comparison -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 944f11c2c2e4f..0eee6262781c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -69,24 +69,24 @@ class Node ( /** * predict value if node is not leaf - * @param feature feature value + * @param features feature value * @return predicted value */ - def predictIfLeaf(feature: Vector) : Double = { + def predict(features: Vector) : Double = { if (isLeaf) { predict } else{ if (split.get.featureType == Continuous) { - if (feature(split.get.feature) <= split.get.threshold) { - leftNode.get.predictIfLeaf(feature) + if (features(split.get.feature) <= split.get.threshold) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } else { - if (split.get.categories.contains(feature(split.get.feature))) { - leftNode.get.predictIfLeaf(feature) + if (split.get.categories.contains(features(split.get.feature))) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index d7ffd386c05ee..50fb48b40de3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * :: DeveloperApi :: * Split applied to a feature * @param feature feature index - * @param threshold threshold for continuous feature + * @param threshold Threshold for continuous feature. + * Split left if feature <= threshold, else right. * @param featureType type of feature -- categorical or continuous - * @param categories accepted values for categorical variables + * @param categories Split left if categorical feature value is in this set, else right. */ @DeveloperApi case class Split( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a5c49a38dc08f..2f36fd907772c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -23,10 +23,10 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} -import org.apache.spark.mllib.tree.impl.TreePoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -82,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -162,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Check splits. @@ -279,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -373,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -428,10 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -456,10 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -495,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -503,9 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -518,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -526,9 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -542,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -550,9 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -566,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -574,9 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -590,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -598,14 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) - val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) - val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val nodes: Array[Node] = new Array[Node](7) + nodes(0) = modelOneNode.topNode + nodes(0).leftNode = None + nodes(0).rightNode = None + val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -613,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, - filters, splits, bins, 0) + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, + nodes, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) assert(bestSplitsWithGroups(1)._2.gain > 0) @@ -629,19 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) } - } test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -657,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -688,20 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + numClassesForClassification = 3, maxBins = maxBins, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -716,18 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -741,18 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -765,14 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 From d411f4190252546b0ea99c1934efd5e5f84be50c Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 17 Aug 2014 15:48:39 -0700 Subject: [PATCH 165/231] SPARK-2881: Upgrade to Snappy 1.0.5.3 to avoid SPARK-2881. This version of Snappy was released with a backported fix specifically for Spark. This fixes an issue where names collide in the snappy .so file when users are submitting jobs as different users on the same cluster. Author: Patrick Wendell Closes #1999 from pwendell/snappy-upgrade and squashes the following commits: 38974ff [Patrick Wendell] SPARK-2881: Upgrade to Snappy 1.0.5.3 to avoid SPARK-2881. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index c87f776bda659..9e5217e294681 100644 --- a/pom.xml +++ b/pom.xml @@ -315,7 +315,7 @@ org.xerial.snappy snappy-java - 1.0.5 + 1.0.5.3 net.jpountz.lz4 From c6a0091ea401e0bec58d7607eb42be89cc090868 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 18:10:45 -0700 Subject: [PATCH 166/231] Revert "[SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled" Revert #1891 due to issues with hadoop 1 compatibility. Author: Michael Armbrust Closes #2007 from marmbrus/revert1891 and squashes the following commits: 68706c0 [Michael Armbrust] Revert "[SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled" (cherry picked from commit 5ecb08ea063166564178885b7515abef0d76eecb) Signed-off-by: Michael Armbrust --- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index c16a7d3661c66..b092f42372171 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -26,8 +26,6 @@ import jline.{ConsoleReader, History} import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} @@ -118,17 +116,13 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - /** - * This should be executed before shutdown hook of - * FileSystem to avoid race condition of FileSystem operation - */ - ShutdownHookManager.get.addShutdownHook( + Runtime.getRuntime.addShutdownHook( new Thread() { override def run() { SparkSQLEnv.stop() } } - , FileSystem.SHUTDOWN_HOOK_PRIORITY - 1) + ) // "-h" option has been passed, so connect to Hive thrift server. if (sessionState.getHost != null) { From 4f776dfab726f54c948a83a7157b958903c15ecf Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 19:00:38 -0700 Subject: [PATCH 167/231] [SQL] Improve debug logging and toStrings. Author: Michael Armbrust Closes #2004 from marmbrus/codgenDebugging and squashes the following commits: b7a7e41 [Michael Armbrust] Improve debug logging and toStrings. (cherry picked from commit bfa09b01d7eddc572cd22ca2e418a735b4ccc826) Signed-off-by: Michael Armbrust --- .../expressions/codegen/CodeGenerator.scala | 21 +++++++++++++++++-- .../catalyst/expressions/nullFunctions.scala | 2 ++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5b398695bf560..de2d67ce82ff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -78,7 +78,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .build( new CacheLoader[InType, OutType]() { override def load(in: InType): OutType = globalLock.synchronized { - create(in) + val startTime = System.nanoTime() + val result = create(in) + val endTime = System.nanoTime() + def timeMs = (endTime - startTime).toDouble / 1000000 + logInfo(s"Code generated expression $in in $timeMs ms") + result } }) @@ -413,7 +418,19 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin """.children } - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + // Only inject debugging code if debugging is turned on. + val debugCode = + if (log.isDebugEnabled) { + val localLogger = log + val localLoggerTree = reify { localLogger } + q""" + $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + """ :: Nil + } else { + Nil + } + + EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) } protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index ce6d99c911ab3..e88c5d4fa178a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -60,6 +60,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def eval(input: Row): Any = { child.eval(input) == null } + + override def toString = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { From 826356725ffb3189180f7879d3f9c449924785f3 Mon Sep 17 00:00:00 2001 From: Chris Fregly Date: Sun, 17 Aug 2014 19:33:15 -0700 Subject: [PATCH 168/231] [SPARK-1981] updated streaming-kinesis.md fixed markup, separated out sections more-clearly, more thorough explanations Author: Chris Fregly Closes #1757 from cfregly/master and squashes the following commits: 9b1c71a [Chris Fregly] better explained why spark checkpoints are disabled in the example (due to no stateful operations being used) 0f37061 [Chris Fregly] SPARK-1981: (Kinesis streaming support) updated streaming-kinesis.md 862df67 [Chris Fregly] Merge remote-tracking branch 'upstream/master' 8e1ae2e [Chris Fregly] Merge remote-tracking branch 'upstream/master' 4774581 [Chris Fregly] updated docs, renamed retry to retryRandom to be more clear, removed retries around store() method 0393795 [Chris Fregly] moved Kinesis examples out of examples/ and back into extras/kinesis-asl 691a6be [Chris Fregly] fixed tests and formatting, fixed a bug with JavaKinesisWordCount during union of streams 0e1c67b [Chris Fregly] Merge remote-tracking branch 'upstream/master' 74e5c7c [Chris Fregly] updated per TD's feedback. simplified examples, updated docs e33cbeb [Chris Fregly] Merge remote-tracking branch 'upstream/master' bf614e9 [Chris Fregly] per matei's feedback: moved the kinesis examples into the examples/ dir d17ca6d [Chris Fregly] per TD's feedback: updated docs, simplified the KinesisUtils api 912640c [Chris Fregly] changed the foundKinesis class to be a publically-avail class db3eefd [Chris Fregly] Merge remote-tracking branch 'upstream/master' 21de67f [Chris Fregly] Merge remote-tracking branch 'upstream/master' 6c39561 [Chris Fregly] parameterized the versions of the aws java sdk and kinesis client 338997e [Chris Fregly] improve build docs for kinesis 828f8ae [Chris Fregly] more cleanup e7c8978 [Chris Fregly] Merge remote-tracking branch 'upstream/master' cd68c0d [Chris Fregly] fixed typos and backward compatibility d18e680 [Chris Fregly] Merge remote-tracking branch 'upstream/master' b3b0ff1 [Chris Fregly] [SPARK-1981] Add AWS Kinesis streaming support (cherry picked from commit 99243288b049f4a4fb4ba0505ea2310be5eb4bd2) Signed-off-by: Tathagata Das --- docs/streaming-kinesis.md | 97 ++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md index 801c905c88df8..16ad3222105a2 100644 --- a/docs/streaming-kinesis.md +++ b/docs/streaming-kinesis.md @@ -3,56 +3,57 @@ layout: global title: Spark Streaming Kinesis Receiver --- -### Kinesis -Build notes: -
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • -
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • -
  • To build with Kinesis, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • Applications will need to link to the 'spark-streaming-kinesis-asl` artifact.
  • +## Kinesis +###Design +
  • The KinesisReceiver uses the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License.
  • +
  • The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concept of Workers, Checkpoints, and Shard Leases.
  • +
  • The KCL uses DynamoDB to maintain all state. A DynamoDB table is created in the us-east-1 region (regardless of Kinesis stream region) during KCL initialization for each Kinesis application name.
  • +
  • A single KinesisReceiver can process many shards of a stream by spinning up multiple KinesisRecordProcessor threads.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream as each will spin up at least one KinesisRecordProcessor thread.
  • +
  • Horizontal scaling is achieved by autoscaling additional KinesisReceiver (separate processes) or spinning up new KinesisRecordProcessor threads within each KinesisReceiver - up to the number of current shards for a given stream, of course. Don't forget to autoscale back down!
  • -Kinesis examples notes: -
  • To build the Kinesis examples, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • -
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • -
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • +### Build +
  • Spark supports a Streaming KinesisReceiver, but it is not included in the default build due to Amazon Software Licensing (ASL) restrictions.
  • +
  • To build with the Kinesis Streaming Receiver and supporting ASL-licensed code, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
  • All KinesisReceiver-related code, examples, tests, and artifacts live in **$SPARK_HOME/extras/kinesis-asl/**.
  • +
  • Kinesis-based Spark Applications will need to link to the **spark-streaming-kinesis-asl** artifact that is built when **-Pkinesis-asl** is specified.
  • +
  • _**Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -Deployment and runtime notes: -
  • A single KinesisReceiver can process many shards of a stream.
  • -
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • -
  • You never need more KinesisReceivers than the number of shards in your stream.
  • -
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • -
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • -
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    - 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    - 2) Java System Properties - aws.accessKeyId and aws.secretKey
    - 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    - 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    -
  • -
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    - http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • -
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • -
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, -retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • -
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). -Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, -it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
  • +###Example +
  • To build the Kinesis example, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
  • You need to setup a Kinesis stream at one of the valid Kinesis endpoints with 1 or more shards per the following: http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoints can be found here: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When running **locally**, the example automatically determines the number of threads and KinesisReceivers to spin up based on the number of shards configured for the stream. Therefore, **local[n]** is not needed when starting the example as with other streaming examples.
  • +
  • While this example could use a single KinesisReceiver which spins up multiple KinesisRecordProcessor threads to process multiple shards, I wanted to demonstrate unioning multiple KinesisReceivers as a single DStream. (It's a bit confusing in local mode.)
  • +
  • **KinesisWordCountProducerASL** is provided to generate random records into the Kinesis stream for testing.
  • +
  • The example has been configured to immediately replicate incoming stream data to another node by using (StorageLevel.MEMORY_AND_DISK_2) +
  • Spark checkpointing is disabled because the example does not use any stateful or window-based DStream operations such as updateStateByKey and reduceByWindow. If those operations are introduced, you would need to enable checkpointing or risk losing data in the case of a failure.
  • +
  • Kinesis checkpointing is enabled. This means that the example will recover from a Kinesis failure.
  • +
  • The example uses InitialPositionInStream.LATEST strategy to pull from the latest tip of the stream if no Kinesis checkpoint info exists.
  • +
  • In our example, **KinesisWordCount** is the Kinesis application name for both the Scala and Java versions. The use of this application name is described next.
  • -Failure recovery notes: -
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    - 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    - 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    - 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
    +###Deployment and Runtime +
  • A Kinesis application name must be unique for a given account and region.
  • +
  • A DynamoDB table and CloudWatch namespace are created during KCL initialization using this Kinesis application name. http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization
  • +
  • This DynamoDB table lives in the us-east-1 region regardless of the Kinesis endpoint URL.
  • +
  • Changing the app name or stream name could lead to Kinesis errors as only a single logical application can process a single stream.
  • +
  • If you are seeing errors after changing the app name or stream name, it may be necessary to manually delete the DynamoDB table and start from scratch.
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the KCL.
  • +
  • The KinesisReceiver uses the DefaultAWSCredentialsProviderChain for AWS credentials which searches for credentials in the following order of precedence:
    +1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    +2) Java System Properties - aws.accessKeyId and aws.secretKey
    +3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    +4) Instance profile credentials - delivered through the Amazon EC2 metadata service
  • -
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • -
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • -
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) -or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • -
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • -
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • -
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data -depending on the checkpoint frequency.
  • -
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
  • + +###Fault-Tolerance +
  • The combination of Spark Streaming and Kinesis creates 2 different checkpoints that may occur at different intervals.
  • +
  • Checkpointing too frequently against Kinesis will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random backoff retry strategy.
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last Kinesis checkpoint sequence number recorded per shard (stored in the DynamoDB table).
  • +
  • If no Kinesis checkpoint info exists, the KinesisReceiver will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running (and no checkpoint info is being stored.)
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency.
  • Record processing should be idempotent when possible.
  • -
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • -
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
  • +
  • A failed or latent KinesisRecordProcessor within the KinesisReceiver will be detected and automatically restarted by the KCL.
  • +
  • If possible, the KinesisReceiver should be shutdown cleanly in order to trigger a final checkpoint of all KinesisRecordProcessors to avoid duplicate record processing.
  • \ No newline at end of file From 8438daf2c2a04e48465fc2681d142ca5a6dec747 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 17 Aug 2014 20:53:18 -0700 Subject: [PATCH 169/231] [SPARK-3087][MLLIB] fix col indexing bug in chi-square and add a check for number of distinct values There is a bug determining the column index. dorx Author: Xiangrui Meng Closes #1997 from mengxr/chisq-index and squashes the following commits: 8fc2ab2 [Xiangrui Meng] fix col indexing bug and add a check for number of distinct values (cherry picked from commit c77f40668fbb5b8bca9a9b25c039895cb7a4a80c) Signed-off-by: Xiangrui Meng --- .../apache/spark/mllib/stat/Statistics.scala | 2 +- .../spark/mllib/stat/test/ChiSqTest.scala | 37 +++++++++++++++---- .../mllib/stat/HypothesisTestSuite.scala | 37 ++++++++++++++----- 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 3cf1028fbc725..3cf4e807b4cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -155,7 +155,7 @@ object Statistics { * :: Experimental :: * Conduct Pearson's independence test for every feature against the label across the input RDD. * For each feature, the (feature, label) pairs are converted into a contingency matrix for which - * the chi-squared statistic is computed. + * the chi-squared statistic is computed. All label and feature values must be categorical. * * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. * Real-valued features will be treated as categorical for each distinct value. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 215de95db5113..0089419c2c5d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test import breeze.linalg.{DenseMatrix => BDM} import cern.jet.stat.Probability.chiSquareComplemented -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import scala.collection.mutable + /** * Conduct the chi-squared test for the input RDDs using the specified method. * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted @@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging { */ def chiSquaredFeatures(data: RDD[LabeledPoint], methodName: String = PEARSON.name): Array[ChiSqTestResult] = { + val maxCategories = 10000 val numCols = data.first().features.size val results = new Array[ChiSqTestResult](numCols) var labels: Map[Double, Int] = null - // At most 100 columns at a time - val batchSize = 100 + // at most 1000 columns at a time + val batchSize = 1000 var batch = 0 while (batch * batchSize < numCols) { // The following block of code can be cleaned up and made public as // chiSquared(data: RDD[(V1, V2)]) val startCol = batch * batchSize val endCol = startCol + math.min(batchSize, numCols - startCol) - val pairCounts = data.flatMap { p => - // assume dense vectors - p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => - (col, feature, p.label) + val pairCounts = data.mapPartitions { iter => + val distinctLabels = mutable.HashSet.empty[Double] + val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] = + Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*) + var i = 1 + iter.flatMap { case LabeledPoint(label, features) => + if (i % 1000 == 0) { + if (distinctLabels.size > maxCategories) { + throw new SparkException(s"Chi-square test expect factors (categorical values) but " + + s"found more than $maxCategories distinct label values.") + } + allDistinctFeatures.foreach { case (col, distinctFeatures) => + if (distinctFeatures.size > maxCategories) { + throw new SparkException(s"Chi-square test expect factors (categorical values) but " + + s"found more than $maxCategories distinct values in column $col.") + } + } + } + i += 1 + distinctLabels += label + features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) => + allDistinctFeatures(col) += feature + (col, feature, label) + } } }.countByValue() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 5bd0521298c14..6de3840b3f198 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.mllib.stat +import java.util.Random + import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest @@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { // labels: 1.0 (2 / 6), 0.0 (4 / 6) // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) - val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), - new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), - new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), - new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), - new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), - new LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + val data = Seq( + LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) for (numParts <- List(2, 4, 6, 8)) { val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) val feature1 = chi(0) @@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { } // Test that the right number of results is returned - val numCols = 321 - val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), - new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0))))) + val numCols = 1001 + val sparseData = Array( + new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) assert(chi.size === numCols) + assert(chi(1000) != null) // SPARK-3087 + + // Detect continous features or labels + val random = new Random(11L) + val continuousLabel = + Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousLabel, 2)) + } + val continuousFeature = + Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) + } } } From a5ae720745d744ec29741b49d2d362f362d53fa4 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 17 Aug 2014 22:29:58 -0700 Subject: [PATCH 170/231] SPARK-2884: Create binary builds in parallel with release script. --- dev/create-release/create-release.sh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 1867cf4ec46ca..28f26d2368254 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -117,12 +117,13 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & make_binary_release "hadoop2" \ - "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & make_binary_release "hadoop2-without-hive" \ - "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & +wait # Copy data echo "Copying release tarballs" From 0506539b0e853d474183078814fb0f550bfbbd67 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 17 Aug 2014 22:39:06 -0700 Subject: [PATCH 171/231] SPARK-2900. aggregate inputBytes per stage Author: Sandy Ryza Closes #1826 from sryza/sandy-spark-2900 and squashes the following commits: 43f9091 [Sandy Ryza] SPARK-2900 (cherry picked from commit df652ea02a3e42d987419308ef14874300347373) Signed-off-by: Patrick Wendell --- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 6 ++++++ .../apache/spark/ui/jobs/JobProgressListenerSuite.scala | 9 ++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index a3e9566832d06..74cd637d88155 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -200,6 +200,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.shuffleReadBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta + val inputBytesDelta = + (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) + stageData.inputBytes += inputBytesDelta + execSummary.inputBytes += inputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index f5ba31c309277..147ec0bc52e39 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.{LocalSparkContext, SparkConf, Success} -import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -150,6 +150,9 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.executorRunTime = base + 4 taskMetrics.diskBytesSpilled = base + 5 taskMetrics.memoryBytesSpilled = base + 6 + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.bytesRead = base + 7 taskMetrics } @@ -182,6 +185,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 205) assert(stage0Data.memoryBytesSpilled == 112) assert(stage1Data.memoryBytesSpilled == 206) + assert(stage0Data.inputBytes == 114) + assert(stage1Data.inputBytes == 207) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 2) assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get @@ -208,6 +213,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 610) assert(stage0Data.memoryBytesSpilled == 412) assert(stage1Data.memoryBytesSpilled == 612) + assert(stage0Data.inputBytes == 414) + assert(stage1Data.inputBytes == 614) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 302) assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get From 708cde99a142c90f5a06c7aa326b622d80022e3d Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 17 Aug 2014 23:29:44 -0700 Subject: [PATCH 172/231] [SPARK-3097][MLlib] Word2Vec performance improvement mengxr Please review the code. Adding weights in reduceByKey soon. Only output model entry for words appeared in the partition before merging and use reduceByKey to combine model. In general, this implementation is 30s or so faster than implementation using big array. Author: Liquan Pei Closes #1932 from Ishiihara/Word2Vec-improve2 and squashes the following commits: d5377a9 [Liquan Pei] use syn0Global and syn1Global to represent model cad2011 [Liquan Pei] bug fix for synModify array out of bound 083aa66 [Liquan Pei] update synGlobal in place and reduce synOut size 9075e1c [Liquan Pei] combine syn0Global and syn1Global to synGlobal aa2ab36 [Liquan Pei] use reduceByKey to combine models (cherry picked from commit 3c8fa505900ac158d57de36f6b0fd6da05f8893b) Signed-off-by: Xiangrui Meng --- .../apache/spark/mllib/feature/Word2Vec.scala | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index ecd49ea2ff533..d2ae62b482aff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap /** * Entry in vocabulary @@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging { var syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) var syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val syn0Modify = new Array[Int](vocabSize) + val syn1Modify = new Array[Int](vocabSize) val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount @@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging { // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * vectorSize + val inner = bcVocab.value(word).point(d) + val l2 = inner * vectorSize // Propagate hidden -> output var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { @@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging { val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + syn1Modify(inner) += 1 } d += 1 } blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) + syn0Modify(lastWord) += 1 } } a += 1 @@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging { } (syn0, syn1, lwc, wc) } - Iterator(model) + val syn0Local = model._1 + val syn1Local = model._2 + val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2) + var index = 0 + while(index < vocabSize) { + if (syn0Modify(index) != 0) { + synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + if (syn1Modify(index) != 0) { + synOut.update(index + vocabSize, + syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + index += 1 + } + Iterator(synOut) } - val (aggSyn0, aggSyn1, _, _) = - partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - blas.sscal(n, weight1, syn0_1, 1) - blas.sscal(n, weight1, syn1_1, 1) - blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) => + blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) + v1 + }.collect() + var i = 0 + while (i < synAgg.length) { + val index = synAgg(i)._1 + if (index < vocabSize) { + Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize) + } else { + Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize) } - syn0Global = aggSyn0 - syn1Global = aggSyn1 + i += 1 + } } newSentences.unpersist() From 518258f1ba4d79a72e1a97ebebb1b51cd392c503 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 17 Aug 2014 23:30:47 -0700 Subject: [PATCH 173/231] [SPARK-2842][MLlib]Word2Vec documentation mengxr Documentation for Word2Vec Author: Liquan Pei Closes #2003 from Ishiihara/Word2Vec-doc and squashes the following commits: 4ff11d4 [Liquan Pei] minor fix 8d7458f [Liquan Pei] code reformat 6df0dcb [Liquan Pei] add Word2Vec documentation (cherry picked from commit eef779b8d631de971d440051cae21040f4de558f) Signed-off-by: Xiangrui Meng --- docs/mllib-feature-extraction.md | 63 +++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 21453cb9cd8c9..4b3cb715c58c7 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -9,4 +9,65 @@ displayTitle: MLlib - Feature Extraction ## Word2Vec -## TFIDF +Word2Vec computes distributed vector representation of words. The main advantage of the distributed +representations is that similar words are close in the vector space, which makes generalization to +novel patterns easier and model estimation more robust. Distributed vector representation is +showed to be useful in many natural language processing applications such as named entity +recognition, disambiguation, parsing, tagging and machine translation. + +### Model + +In our implementation of Word2Vec, we used skip-gram model. The training objective of skip-gram is +to learn word vector representations that are good at predicting its context in the same sentence. +Mathematically, given a sequence of training words `$w_1, w_2, \dots, w_T$`, the objective of the +skip-gram model is to maximize the average log-likelihood +`\[ +\frac{1}{T} \sum_{t = 1}^{T}\sum_{j=-k}^{j=k} \log p(w_{t+j} | w_t) +\]` +where $k$ is the size of the training window. + +In the skip-gram model, every word $w$ is associated with two vectors $u_w$ and $v_w$ which are +vector representations of $w$ as word and context respectively. The probability of correctly +predicting word $w_i$ given word $w_j$ is determined by the softmax model, which is +`\[ +p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top}v_{w_j})} +\]` +where $V$ is the vocabulary size. + +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, +we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to +$O(\log(V))$ + +### Example + +The example below demonstrates how to load a text file, parse it as an RDD of `Seq[String]`, +construct a `Word2Vec` instance and then fit a `Word2VecModel` with the input data. Finally, +we display the top 40 synonyms of the specified word. To run the example, first download +the [text8](http://mattmahoney.net/dc/text8.zip) data and extract it to your preferred directory. +Here we assume the extracted file is `text8` and in same directory as you run the spark shell. + +
    +
    +{% highlight scala %} +import org.apache.spark._ +import org.apache.spark.rdd._ +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.feature.Word2Vec + +val input = sc.textFile("text8").map(line => line.split(" ").toSeq) + +val word2vec = new Word2Vec() + +val model = word2vec.fit(input) + +val synonyms = model.findSynonyms("china", 40) + +for((synonym, cosineSimilarity) <- synonyms) { + println(s"$synonym $cosineSimilarity") +} +{% endhighlight %} +
    +
    + +## TFIDF \ No newline at end of file From e0bc333b6ad36feac5397600fe6948dcb37a8e44 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Mon, 18 Aug 2014 01:15:45 -0700 Subject: [PATCH 174/231] [MLlib] Remove transform(dataset: RDD[String]) from Word2Vec public API mengxr Remove transform(dataset: RDD[String]) from public API. Author: Liquan Pei Closes #2010 from Ishiihara/Word2Vec-api and squashes the following commits: 17b1031 [Liquan Pei] remove transform(dataset: RDD[String]) from public API (cherry picked from commit 9306b8c6c8c412b9d0d5cffb6bd7a87784f0f6bf) Signed-off-by: Xiangrui Meng --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d2ae62b482aff..1dcaa2cd2e630 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -434,15 +434,6 @@ class Word2VecModel private[mllib] ( } } - /** - * Transforms an RDD to its vector representation - * @param dataset a an RDD of words - * @return RDD of vector representation - */ - def transform(dataset: RDD[String]): RDD[Vector] = { - dataset.map(word => transform(word)) - } - /** * Find synonyms of a word * @param word a word From 12f16ba3fa1f3cde9f43c094029017f4192b1bac Mon Sep 17 00:00:00 2001 From: Chandan Kumar Date: Mon, 18 Aug 2014 09:52:25 -0700 Subject: [PATCH 175/231] [SPARK-2862] histogram method fails on some choices of bucketCount Author: Chandan Kumar Closes #1787 from nrchandan/spark-2862 and squashes the following commits: a76bbf6 [Chandan Kumar] [SPARK-2862] Fix for a broken test case and add new test cases 4211eea [Chandan Kumar] [SPARK-2862] Add Scala bug id 13854f1 [Chandan Kumar] [SPARK-2862] Use shorthand range notation to avoid Scala bug (cherry picked from commit f45efbb8aaa65bc46d65e77e93076fbc29f4455d) Signed-off-by: Xiangrui Meng --- .../apache/spark/rdd/DoubleRDDFunctions.scala | 15 ++++++++---- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index f233544d128f5..e0494ee39657c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -95,7 +95,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the elements in RDD do not vary (max == min) always returns a single bucket. */ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { - // Compute the minimum and the maxium + // Scala's built-in range has issues. See #SI-8782 + def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { + val span = max - min + Range.Int(0, steps, 1).map(s => min + (s * span) / steps) :+ max + } + // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(Double.NegativeInfinity, Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => @@ -107,9 +112,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { throw new UnsupportedOperationException( "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } - val increment = (max-min)/bucketCount.toDouble - val range = if (increment != 0) { - Range.Double.inclusive(min, max, increment) + val range = if (min != max) { + // Range.Double.inclusive(min, max, increment) + // The above code doesn't always work. See Scala bug #SI-8782. + // https://issues.scala-lang.org/browse/SI-8782 + customRange(min, max, bucketCount) } else { List(min, min) } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index a822bd18bfdbd..f89bdb6e07dea 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -245,6 +245,29 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithoutBucketsForLargerDatasets") { + // Verify the case of slighly larger datasets + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(8) + val expectedHistogramResults = + Array(12, 12, 11, 12, 12, 11, 12, 12) + val expectedHistogramBuckets = + Array(6.0, 17.625, 29.25, 40.875, 52.5, 64.125, 75.75, 87.375, 99.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsWithIrrationalBucketEdges") { + // Verify the case of buckets with irrational edges. See #SPARK-2862. + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(9) + val expectedHistogramResults = + Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets(0) === 6.0) + assert(histogramBuckets(9) === 99.0) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity From ec0b91edd592cf89be349e0e5ad7553e02f70cd3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 18 Aug 2014 10:00:46 -0700 Subject: [PATCH 176/231] SPARK-3096: Include parquet hive serde by default in build A small change - we should just add this dependency. It doesn't have any recursive deps and it's needed for reading have parquet tables. Author: Patrick Wendell Closes #2009 from pwendell/parquet and squashes the following commits: e411f9f [Patrick Wendell] SPARk-309: Include parquet hive serde by default in build (cherry picked from commit 7ae28d1247e4756219016206c51fec1656e3917b) Signed-off-by: Michael Armbrust --- sql/hive/pom.xml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index c18a664e737c8..1e689e6d6dcf2 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -35,6 +35,11 @@ + + com.twitter + parquet-hive-bundle + 1.5.0 + org.apache.spark spark-core_${scala.binary.version} From 55e9dd637bdef3a2acf56af95410219e23c9502a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 18 Aug 2014 10:05:52 -0700 Subject: [PATCH 177/231] [SPARK-3084] [SQL] Collect broadcasted tables in parallel in joins BroadcastHashJoin has a broadcastFuture variable that tries to collect the broadcasted table in a separate thread, but this doesn't help because it's a lazy val that only gets initialized when you attempt to build the RDD. Thus queries that broadcast multiple tables would collect and broadcast them sequentially. I changed this to a val to let it start collecting right when the operator is created. Author: Matei Zaharia Closes #1990 from mateiz/spark-3084 and squashes the following commits: f468766 [Matei Zaharia] [SPARK-3084] Collect broadcasted tables in parallel in joins (cherry picked from commit 6a13dca12fac06f3af892ffcc8922cc84f91b786) Signed-off-by: Michael Armbrust --- .../src/main/scala/org/apache/spark/sql/execution/joins.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index c86811e838bd8..481bb8c05e71b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -424,7 +424,7 @@ case class BroadcastHashJoin( UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient - lazy val broadcastFuture = future { + val broadcastFuture = future { sparkContext.broadcast(buildPlan.executeCollect()) } From 4da76fc81c224b04bd652c4a72fb77516a32de0c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 18 Aug 2014 10:45:24 -0700 Subject: [PATCH 178/231] [SPARK-3085] [SQL] Use compact data structures in SQL joins This reuses the CompactBuffer from Spark Core to save memory and pointer dereferences. I also tried AppendOnlyMap instead of java.util.HashMap but unfortunately that slows things down because it seems to do more equals() calls and the equals on GenericRow, and especially JoinedRow, is pretty expensive. Author: Matei Zaharia Closes #1993 from mateiz/spark-3085 and squashes the following commits: 188221e [Matei Zaharia] Remove unneeded import 5f903ee [Matei Zaharia] [SPARK-3085] [SQL] Use compact data structures in SQL joins (cherry picked from commit 4bf3de71074053af94f077c99e9c65a1962739e1) Signed-off-by: Michael Armbrust --- .../apache/spark/sql/execution/joins.scala | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 481bb8c05e71b..b08f9aacc1fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution import java.util.{HashMap => JavaHashMap} -import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.collection.CompactBuffer @DeveloperApi sealed abstract class BuildSide @@ -67,7 +66,7 @@ trait HashJoin { def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. - val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() var currentRow: Row = null // Create a mapping of buildKeys -> rows @@ -77,7 +76,7 @@ trait HashJoin { if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { - val newMatchList = new ArrayBuffer[Row]() + val newMatchList = new CompactBuffer[Row]() hashTable.put(rowKey, newMatchList) newMatchList } else { @@ -89,7 +88,7 @@ trait HashJoin { new Iterator[Row] { private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. @@ -140,7 +139,7 @@ trait HashJoin { /** * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using + * Performs a hash based outer join for two child relations by shuffling the data using * the join keys. This operator requires loading the associated partition in both side into memory. */ @DeveloperApi @@ -179,26 +178,26 @@ case class HashOuterJoin( @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. + // iterator for performance purpose. private[this] def leftOuterIterator( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - leftIter.iterator.flatMap { l => + leftIter.iterator.flatMap { l => joinedRow.withLeft(l) var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => matched = true joinedRow.copy } else { Nil }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + // as we don't know whether we need to append it until finish iterating all of the // records in right side. // If we didn't get any proper row, then append a single row with empty right joinedRow.withRight(rightNullRow).copy @@ -210,20 +209,20 @@ case class HashOuterJoin( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - rightIter.iterator.flatMap { r => + rightIter.iterator.flatMap { r => joinedRow.withRight(r) var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => matched = true joinedRow.copy } else { Nil }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + // as we don't know whether we need to append it until finish iterating all of the // records in left side. // If we didn't get any proper row, then append a single row with empty left. joinedRow.withLeft(leftNullRow).copy @@ -236,7 +235,7 @@ case class HashOuterJoin( val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) if (!key.anyNull) { @@ -246,8 +245,8 @@ case class HashOuterJoin( leftIter.iterator.flatMap[Row] { l => joinedRow.withLeft(l) var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, // append them directly case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { @@ -260,7 +259,7 @@ case class HashOuterJoin( // 2. For those unmatched records in left, append additional records with empty right. // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all + // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. joinedRow.withRight(rightNullRow).copy @@ -268,8 +267,8 @@ case class HashOuterJoin( } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. case (r, idx) if (!rightMatchedSet.contains(idx)) => { joinedRow(leftNullRow, r).copy } @@ -284,15 +283,15 @@ case class HashOuterJoin( } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]() + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) if (existingMatchList == null) { - existingMatchList = new ArrayBuffer[Row]() + existingMatchList = new CompactBuffer[Row]() hashTable.put(rowKey, existingMatchList) } @@ -311,20 +310,20 @@ case class HashOuterJoin( val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) import scala.collection.JavaConversions._ - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") @@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin( /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new ArrayBuffer[Row] + val matchedRows = new CompactBuffer[Row] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) @@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin( val rightNulls = new GenericMutableRow(right.output.size) /** Rows from broadcasted joined with nulls. */ val broadcastRowsWithNulls: Seq[Row] = { - val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + val buf: CompactBuffer[Row] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) case _ => } } i += 1 } - arrBuf.toSeq + buf.toSeq } // TODO: Breaks lineage. From 496f62d9a98067256d8a51fd1e7a485ff6492fa8 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 18 Aug 2014 10:52:20 -0700 Subject: [PATCH 179/231] SPARK-3025 [SQL]: Allow JDBC clients to set a fair scheduler pool This definitely needs review as I am not familiar with this part of Spark. I tested this locally and it did seem to work. Author: Patrick Wendell Closes #1937 from pwendell/scheduler and squashes the following commits: b858e33 [Patrick Wendell] SPARK-3025: Allow JDBC clients to set a fair scheduler pool (cherry picked from commit 6bca8898a1aa4ca7161492229bac1748b3da2ad7) Signed-off-by: Michael Armbrust --- docs/sql-programming-guide.md | 5 ++++ .../scala/org/apache/spark/sql/SQLConf.scala | 3 +++ .../server/SparkSQLOperationManager.scala | 27 ++++++++++++++----- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cd6543945c385..34accade36ea9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -605,6 +605,11 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script comes with Hive. +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + + SET spark.sql.thriftserver.scheduler.pool=accounting; + ### Migration Guide for Shark Users #### Reducer number diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 90de11182e605..56face2992bcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -33,6 +33,9 @@ private[spark] object SQLConf { val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + // This is only used for the thriftserver + val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9338e8121b0fe..699a1103f3248 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,24 +17,24 @@ package org.apache.spark.sql.hive.thriftserver.server -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.math.{random, round} - import java.sql.Timestamp import java.util.{Map => JMap} +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map} +import scala.math.{random, round} + import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession - import org.apache.spark.Logging +import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.catalyst.plans.logical.SetCommand import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -43,6 +43,9 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + // TODO: Currenlty this will grow infinitely, even as sessions expire + val sessionToActivePool = Map[HiveSession, String]() + override def newExecuteStatementOperation( parentSession: HiveSession, statement: String, @@ -165,8 +168,18 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + sessionToActivePool(parentSession) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) + sessionToActivePool.get(parentSession).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } iter = { val resultRdd = result.queryExecution.toRdd val useIncrementalCollect = From 2ae2857986e94d5a8bd5f4660eabe5689463bd21 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 18 Aug 2014 11:00:10 -0700 Subject: [PATCH 180/231] [SPARK-3091] [SQL] Add support for caching metadata on Parquet files For larger Parquet files, reading the file footers (which is done in parallel on up to 5 threads) and HDFS block locations (which is serial) can take multiple seconds. We can add an option to cache this data within FilteringParquetInputFormat. Unfortunately ParquetInputFormat only caches footers within each instance of ParquetInputFormat, not across them. Note: this PR leaves this turned off by default for 1.1, but I believe it's safe to turn it on after. The keys in the hash maps are FileStatus objects that include a modification time, so this will work fine if files are modified. The location cache could become invalid if files have moved within HDFS, but that's rare so I just made it invalidate entries every 15 minutes. Author: Matei Zaharia Closes #2005 from mateiz/parquet-cache and squashes the following commits: dae8efe [Matei Zaharia] Bug fix c71e9ed [Matei Zaharia] Handle empty statuses directly 22072b0 [Matei Zaharia] Use Guava caches and add a config option for caching metadata 8fb56ce [Matei Zaharia] Cache file block locations too 453bd21 [Matei Zaharia] Bug fix 4094df6 [Matei Zaharia] First attempt at caching Parquet footers (cherry picked from commit 9eb74c7d2cbe127dd4c32bf1a8318497b2fb55b6) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/SQLConf.scala | 1 + .../sql/parquet/ParquetTableOperations.scala | 84 ++++++++++++++++--- 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 56face2992bcf..4f2adb006fbc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -32,6 +32,7 @@ private[spark] object SQLConf { val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 759a2a586b926..c6dca10f6ad7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -17,22 +17,23 @@ package org.apache.spark.sql.parquet -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.Try - import java.io.IOException import java.lang.{Long => JLong} import java.text.SimpleDateFormat -import java.util.{Date, List => JList} +import java.util.concurrent.{Callable, TimeUnit} +import java.util.{ArrayList, Collections, Date, List => JList} +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.Try + +import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - import parquet.hadoop._ import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData @@ -41,7 +42,7 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.{Logging, SerializableWritable, TaskContext} @@ -96,6 +97,11 @@ case class ParquetTableScan( ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) } + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + conf.set( + SQLConf.PARQUET_CACHE_METADATA, + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + sc.newAPIHadoopRDD( conf, classOf[FilteringParquetRowInputFormat], @@ -323,10 +329,40 @@ private[parquet] class FilteringParquetRowInputFormat } override def getFooters(jobContext: JobContext): JList[Footer] = { + import FilteringParquetRowInputFormat.footerCache + if (footers eq null) { + val conf = ContextUtil.getConfiguration(jobContext) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap - footers = getFooters(ContextUtil.getConfiguration(jobContext), statuses) + if (statuses.isEmpty) { + footers = Collections.emptyList[Footer] + } else if (!cacheMetadata) { + // Read the footers from HDFS + footers = getFooters(conf, statuses) + } else { + // Read only the footers that are not in the footerCache + val foundFooters = footerCache.getAllPresent(statuses) + val toFetch = new ArrayList[FileStatus] + for (s <- statuses) { + if (!foundFooters.containsKey(s)) { + toFetch.add(s) + } + } + val newFooters = new mutable.HashMap[FileStatus, Footer] + if (toFetch.size > 0) { + val fetched = getFooters(conf, toFetch) + for ((status, i) <- toFetch.zipWithIndex) { + newFooters(status) = fetched.get(i) + } + footerCache.putAll(newFooters) + } + footers = new ArrayList[Footer](statuses.size) + for (status <- statuses) { + footers.add(newFooters.getOrElse(status, foundFooters.get(status))) + } + } } footers @@ -339,6 +375,10 @@ private[parquet] class FilteringParquetRowInputFormat configuration: Configuration, footers: JList[Footer]): JList[ParquetInputSplit] = { + import FilteringParquetRowInputFormat.blockLocationCache + + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) @@ -366,16 +406,23 @@ private[parquet] class FilteringParquetRowInputFormat for (footer <- footers) { val fs = footer.getFile.getFileSystem(configuration) val file = footer.getFile - val fileStatus = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) val parquetMetaData = footer.getParquetMetadata val blocks = parquetMetaData.getBlocks - val fileBlockLocations = fs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen) + var blockLocations: Array[BlockLocation] = null + if (!cacheMetadata) { + blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) + } else { + blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { + def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) + }) + } splits.addAll( generateSplits.invoke( null, blocks, - fileBlockLocations, - fileStatus, + blockLocations, + status, parquetMetaData.getFileMetaData, readContext.getRequestedSchema.toString, readContext.getReadSupportMetadata, @@ -387,6 +434,17 @@ private[parquet] class FilteringParquetRowInputFormat } } +private[parquet] object FilteringParquetRowInputFormat { + private val footerCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .build[FileStatus, Footer]() + + private val blockLocationCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .expireAfterWrite(15, TimeUnit.MINUTES) // Expire locations since HDFS files might move + .build[FileStatus, Array[BlockLocation]]() +} + private[parquet] object FileSystemHelper { def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) From cc4015d2fa3785b92e6ab079b3abcf17627f7c56 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 18 Aug 2014 13:17:10 -0700 Subject: [PATCH 181/231] [SPARK-2406][SQL] Initial support for using ParquetTableScan to read HiveMetaStore tables. This PR adds an experimental flag `spark.sql.hive.convertMetastoreParquet` that when true causes the planner to detects tables that use Hive's Parquet SerDe and instead plans them using Spark SQL's native `ParquetTableScan`. Author: Michael Armbrust Author: Yin Huai Closes #1819 from marmbrus/parquetMetastore and squashes the following commits: 1620079 [Michael Armbrust] Revert "remove hive parquet bundle" cc30430 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into parquetMetastore 4f3d54f [Michael Armbrust] fix style 41ebc5f [Michael Armbrust] remove hive parquet bundle a43e0da [Michael Armbrust] Merge remote-tracking branch 'origin/master' into parquetMetastore 4c4dc19 [Michael Armbrust] Fix bug with tree splicing. ebb267e [Michael Armbrust] include parquet hive to tests pass (Remove this later). c0d9b72 [Michael Armbrust] Avoid creating a HadoopRDD per partition. Add dirty hacks to retrieve partition values from the InputSplit. 8cdc93c [Michael Armbrust] Merge pull request #8 from yhuai/parquetMetastore a0baec7 [Yin Huai] Partitioning columns can be resolved. 1161338 [Michael Armbrust] Add a test to make sure conversion is actually happening 212d5cd [Michael Armbrust] Initial support for using ParquetTableScan to read HiveMetaStore tables. (cherry picked from commit 3abd0c1cda09bb575adc99847a619bc84af37fd0) Signed-off-by: Michael Armbrust --- project/SparkBuild.scala | 1 - .../spark/sql/execution/basicOperators.scala | 12 ++ .../spark/sql/parquet/ParquetRelation.scala | 8 +- .../sql/parquet/ParquetTableOperations.scala | 74 ++++++-- .../apache/spark/sql/hive/HiveContext.scala | 9 + .../spark/sql/hive/HiveStrategies.scala | 119 +++++++++++- .../sql/hive/parquet/FakeParquetSerDe.scala | 56 ++++++ .../sql/parquet/ParquetMetastoreSuite.scala | 171 ++++++++++++++++++ 8 files changed, 427 insertions(+), 23 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 63a285b81a60c..49d52aefca17a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -228,7 +228,6 @@ object SQL { object Hive { lazy val settings = Seq( - javaOptions += "-XX:MaxPermSize=1g", // Multiple queries rely on the TestHive singleton. See comments there for more details. parallelExecution in Test := false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 0027f3cf1fc79..f9dfa3c92f1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -303,3 +303,15 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } + +/** + * :: DeveloperApi :: + * A plan node that does nothing but lie about the output of its child. Used to spice a + * (hopefully structurally equivalent) tree from a different optimization sequence into an already + * resolved tree. + */ +@DeveloperApi +case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { + def children = child :: Nil + def execute() = child.execute() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 053b2a154389c..1713ae6fb5d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -47,7 +47,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} private[sql] case class ParquetRelation( path: String, @transient conf: Option[Configuration], - @transient sqlContext: SQLContext) + @transient sqlContext: SQLContext, + partitioningAttributes: Seq[Attribute] = Nil) extends LeafNode with MultiInstanceRelation { self: Product => @@ -61,12 +62,13 @@ private[sql] case class ParquetRelation( /** Attributes */ override val output = + partitioningAttributes ++ ParquetTypesConverter.readSchemaFromFile( - new Path(path), + new Path(path.split(",").head), conf, sqlContext.isParquetBinaryAsString) - override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index c6dca10f6ad7c..f6cfab736d98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + import parquet.hadoop._ import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData @@ -42,6 +43,7 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} @@ -60,11 +62,18 @@ case class ParquetTableScan( // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes // by exprId. note: output cannot be transient, see // https://issues.apache.org/jira/browse/SPARK-1367 - val output = attributes.map { a => - relation.output - .find(o => o.exprId == a.exprId) - .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relation.output.mkString(",")}")) - } + val normalOutput = + attributes + .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId)) + .flatMap(a => relation.output.find(o => o.exprId == a.exprId)) + + val partOutput = + attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId)) + + def output = partOutput ++ normalOutput + + assert(normalOutput.size + partOutput.size == attributes.size, + s"$normalOutput + $partOutput != $attributes, ${relation.output}") override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext @@ -72,16 +81,19 @@ case class ParquetTableScan( ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) val conf: Configuration = ContextUtil.getConfiguration(job) - val qualifiedPath = { - val path = new Path(relation.path) - path.getFileSystem(conf).makeQualified(path) + + relation.path.split(",").foreach { curPath => + val qualifiedPath = { + val path = new Path(curPath) + path.getFileSystem(conf).makeQualified(path) + } + NewFileInputFormat.addInputPath(job, qualifiedPath) } - NewFileInputFormat.addInputPath(job, qualifiedPath) // Store both requested and original schema in `Configuration` conf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(output)) + ParquetTypesConverter.convertToString(normalOutput)) conf.set( RowWriteSupport.SPARK_ROW_SCHEMA, ParquetTypesConverter.convertToString(relation.output)) @@ -102,13 +114,41 @@ case class ParquetTableScan( SQLConf.PARQUET_CACHE_METADATA, sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) - sc.newAPIHadoopRDD( - conf, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row]) - .map(_._2) - .filter(_ != null) // Parquet's record filters may produce null values + val baseRDD = + new org.apache.spark.rdd.NewHadoopRDD( + sc, + classOf[FilteringParquetRowInputFormat], + classOf[Void], + classOf[Row], + conf) + + if (partOutput.nonEmpty) { + baseRDD.mapPartitionsWithInputSplit { case (split, iter) => + val partValue = "([^=]+)=([^=]+)".r + val partValues = + split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + .getPath + .toString + .split("/") + .flatMap { + case partValue(key, value) => Some(key -> value) + case _ => None + }.toMap + + val partitionRowValues = + partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) + + new Iterator[Row] { + private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null) + + def hasNext = iter.hasNext + + def next() = joinedRow.withRight(iter.next()._2) + } + } + } else { + baseRDD.map(_._2) + }.filter(_ != null) // Parquet's record filters may produce null values } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index a8da676ffa0e0..ff32c7c90a0d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -79,6 +79,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Change the default SQL dialect to HiveQL override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + /** + * When true, enables an experimental feature where metastore tables that use the parquet SerDe + * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive + * SerDe. + */ + private[spark] def convertMetastoreParquet: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -326,6 +334,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { TakeOrdered, ParquetOperations, InMemoryScans, + ParquetConversion, // Must be before HiveTableScans HiveTableScans, DataSinks, Scripts, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 5fcc1bd4b9adf..389ace726d205 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.SQLContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.parquet.{ParquetRelation, ParquetTableScan} + +import scala.collection.JavaConversions._ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. @@ -32,6 +38,115 @@ private[hive] trait HiveStrategies { val hiveContext: HiveContext + /** + * :: Experimental :: + * Finds table scans that would use the Hive SerDe and replaces them with our own native parquet + * table scan operator. + * + * TODO: Much of this logic is duplicated in HiveTableScan. Ideally we would do some refactoring + * but since this is after the code freeze for 1.1 all logic is here to minimize disruption. + * + * Other issues: + * - Much of this logic assumes case insensitive resolution. + */ + @Experimental + object ParquetConversion extends Strategy { + implicit class LogicalPlanHacks(s: SchemaRDD) { + def lowerCase = + new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan)) + + def addPartitioningAttributes(attrs: Seq[Attribute]) = + new SchemaRDD( + s.sqlContext, + s.logicalPlan transform { + case p: ParquetRelation => p.copy(partitioningAttributes = attrs) + }) + } + + implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { + def fakeOutput(newOutput: Seq[Attribute]) = + OutputFaker( + originalPlan.output.map(a => + newOutput.find(a.name.toLowerCase == _.name.toLowerCase) + .getOrElse( + sys.error(s"Can't find attribute $a to fake in set ${newOutput.mkString(",")}"))), + originalPlan) + } + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) + if relation.tableDesc.getSerdeClassName.contains("Parquet") && + hiveContext.convertMetastoreParquet => + + // Filter out all predicates that only deal with partition keys + val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val (pruningPredicates, otherPredicates) = predicates.partition { + _.references.map(_.exprId).subsetOf(partitionKeyIds) + } + + // We are going to throw the predicates and projection back at the whole optimization + // sequence so lets unresolve all the attributes, allowing them to be rebound to the + // matching parquet attributes. + val unresolvedOtherPredicates = otherPredicates.map(_ transform { + case a: AttributeReference => UnresolvedAttribute(a.name) + }).reduceOption(And).getOrElse(Literal(true)) + + val unresolvedProjection = projectList.map(_ transform { + case a: AttributeReference => UnresolvedAttribute(a.name) + }) + + if (relation.hiveQlTable.isPartitioned) { + val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true)) + // Translate the predicate so that it automatically casts the input values to the correct + // data types during evaluation + val castedPredicate = rawPredicate transform { + case a: AttributeReference => + val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId) + val key = relation.partitionKeys(idx) + Cast(BoundReference(idx, StringType, nullable = true), key.dataType) + } + + val inputData = new GenericMutableRow(relation.partitionKeys.size) + val pruningCondition = + if(codegenEnabled) { + GeneratePredicate(castedPredicate) + } else { + InterpretedPredicate(castedPredicate) + } + + val partitions = relation.hiveQlPartitions.filter { part => + val partitionValues = part.getValues + var i = 0 + while (i < partitionValues.size()) { + inputData(i) = partitionValues(i) + i += 1 + } + pruningCondition(inputData) + } + + hiveContext + .parquetFile(partitions.map(_.getLocation).mkString(",")) + .addPartitioningAttributes(relation.partitionKeys) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection:_*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)):: Nil + } else { + hiveContext + .parquetFile(relation.hiveQlTable.getDataLocation.getPath) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection:_*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)) :: Nil + } + case _ => Nil + } + } + object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala new file mode 100644 index 0000000000000..544abfc32423c --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.parquet + +import java.util.Properties + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category +import org.apache.hadoop.hive.serde2.{SerDeStats, SerDe} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.io.Writable + +/** + * A placeholder that allows SparkSQL users to create metastore tables that are stored as + * parquet files. It is only intended to pass the checks that the serde is valid and exists + * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan + * when "spark.sql.hive.convertMetastoreParquet" is set to true. + */ +@deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " + + "placeholder in the Hive MetaStore") +class FakeParquetSerDe extends SerDe { + override def getObjectInspector: ObjectInspector = new ObjectInspector { + override def getCategory: Category = Category.PRIMITIVE + + override def getTypeName: String = "string" + } + + override def deserialize(p1: Writable): AnyRef = throwError + + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getSerializedClass: Class[_ <: Writable] = throwError + + override def getSerDeStats: SerDeStats = throwError + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = throwError + + private def throwError = + sys.error( + "spark.sql.hive.convertMetastoreParquet must be set to true to use FakeParquetSerDe") +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala new file mode 100644 index 0000000000000..0723be7298e15 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -0,0 +1,171 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.File + +import org.apache.spark.sql.hive.execution.HiveTableScan +import org.scalatest.BeforeAndAfterAll + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +case class ParquetData(intField: Int, stringField: String) + +/** + * Tests for our SerDe -> Native parquet scan conversion. + */ +class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "false") + } + + val partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sql(s""" + create external table partitioned_parquet + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDir.getCanonicalPath}' + """) + + sql(s""" + create external table normal_parquet + ( + intField INT, + stringField STRING + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + """) + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + test("project the partitioning column") { + checkAnswer( + sql("SELECT p, count(*) FROM partitioned_parquet group by p"), + (1, 10) :: + (2, 10) :: + (3, 10) :: + (4, 10) :: + (5, 10) :: + (6, 10) :: + (7, 10) :: + (8, 10) :: + (9, 10) :: + (10, 10) :: Nil + ) + } + + test("project partitioning and non-partitioning columns") { + checkAnswer( + sql("SELECT stringField, p, count(intField) " + + "FROM partitioned_parquet GROUP BY p, stringField"), + ("part-1", 1, 10) :: + ("part-2", 2, 10) :: + ("part-3", 3, 10) :: + ("part-4", 4, 10) :: + ("part-5", 5, 10) :: + ("part-6", 6, 10) :: + ("part-7", 7, 10) :: + ("part-8", 8, 10) :: + ("part-9", 9, 10) :: + ("part-10", 10, 10) :: Nil + ) + } + + test("simple count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet"), + 100) + } + + test("pruned count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE p = 1"), + 10) + } + + test("multi-partition pruned count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE p IN (1,2,3)"), + 30) + } + + test("non-partition predicates") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE intField IN (1,2,3)"), + 30) + } + + test("sum") { + checkAnswer( + sql("SELECT SUM(intField) FROM partitioned_parquet WHERE intField IN (1,2,3) AND p = 1"), + 1 + 2 + 3 + ) + } + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + 10 + ) + } + + test("conversion is working") { + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: HiveTableScan => true + }.isEmpty) + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: ParquetTableScan => true + }.nonEmpty) + } +} From e083334634ca0d7a25dee864fb2b9558ee92a2f7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 18 Aug 2014 13:58:35 -0700 Subject: [PATCH 182/231] [SPARK-3103] [PySpark] fix saveAsTextFile() with utf-8 bugfix: It will raise an exception when it try to encode non-ASCII strings into unicode. It should only encode unicode as "utf-8". Author: Davies Liu Closes #2018 from davies/fix_utf8 and squashes the following commits: 4db7967 [Davies Liu] fix saveAsTextFile() with utf-8 (cherry picked from commit d1d0ee41c27f1d07fed0c5d56ba26c723cc3dc26) Signed-off-by: Josh Rosen --- python/pyspark/rdd.py | 4 +++- python/pyspark/tests.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 240381e5bae12..c708b69cc1e31 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1191,7 +1191,9 @@ def func(split, iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) - yield x.encode("utf-8") + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f1fece998cd54..69d543d9d045d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -256,6 +256,15 @@ def test_save_as_textfile_with_unicode(self): raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + def test_save_as_textfile_with_utf8(self): + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x.encode("utf-8")]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) + self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + def test_transforming_cartesian_result(self): # Regression test for SPARK-1034 rdd1 = self.sc.parallelize([1, 2]) From 25cabd7eec6e499fce94bce0d45087e9d8726a50 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 18 Aug 2014 14:10:10 -0700 Subject: [PATCH 183/231] [SPARK-2718] [yarn] Handle quotes and other characters in user args. Due to the way Yarn runs things through bash, normal quoting doesn't work as expected. This change applies the necessary voodoo to the user args to avoid issues with bash and special characters. The change also uncovered an issue with the event logger app name sanitizing code; it wasn't cleaning up all "bad" characters, so sometimes it would fail to create the log dirs. I just added some more bad character replacements. Author: Marcelo Vanzin Closes #1724 from vanzin/SPARK-2718 and squashes the following commits: cc84b89 [Marcelo Vanzin] Review feedback. c1a257a [Marcelo Vanzin] Add test for backslashes. 55571d4 [Marcelo Vanzin] Unbreak yarn-client. 515613d [Marcelo Vanzin] [SPARK-2718] [yarn] Handle quotes and other characters in user args. (cherry picked from commit 6201b27643023569e19b68aa9d5c4e4e59ce0d79) Signed-off-by: Andrew Or --- .../scheduler/EventLoggingListener.scala | 3 +- .../yarn/ApplicationMasterArguments.scala | 6 +- .../apache/spark/deploy/yarn/ClientBase.scala | 9 +-- .../deploy/yarn/ExecutorRunnableUtil.scala | 4 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 25 ++++++++ .../yarn/YarnSparkHadoopUtilSuite.scala | 64 +++++++++++++++++++ 6 files changed, 101 insertions(+), 10 deletions(-) create mode 100644 yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 7378ce923f0ae..370fcd85aa680 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -54,7 +54,8 @@ private[spark] class EventLoggingListener( private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") - private val name = appName.replaceAll("[ :/]", "-").toLowerCase + "-" + System.currentTimeMillis + private val name = appName.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_") + .toLowerCase + "-" + System.currentTimeMillis val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 4c383ab574abe..424b0fb0936f2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS parseArgs(args.toList) - + private def parseArgs(inputArgs: List[String]): Unit = { val userArgsBuffer = new ArrayBuffer[String]() @@ -47,7 +47,7 @@ class ApplicationMasterArguments(val args: Array[String]) { userClass = value args = tail - case ("--args") :: value :: tail => + case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -75,7 +75,7 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgs = userArgsBuffer.readOnly } - + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 1da0a1b675554..3897b3a373a8c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -300,11 +300,11 @@ trait ClientBase extends Logging { } def userArgsToString(clientArgs: ClientArguments): String = { - val prefix = " --args " + val prefix = " --arg " val args = clientArgs.userArgs val retval = new StringBuilder() for (arg <- args) { - retval.append(prefix).append(" '").append(arg).append("' ") + retval.append(prefix).append(" ").append(YarnSparkHadoopUtil.escapeForShell(arg)) } retval.toString } @@ -386,7 +386,7 @@ trait ClientBase extends Logging { // TODO: it might be nicer to pass these as an internal environment variable rather than // as Java options, due to complications with string parsing of nested quotes. for ((k, v) <- sparkConf.getAll) { - javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" + javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } if (args.amClass == classOf[ApplicationMaster].getName) { @@ -400,7 +400,8 @@ trait ClientBase extends Logging { // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ javaOpts ++ - Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar, + Seq(args.amClass, "--class", YarnSparkHadoopUtil.escapeForShell(args.userClass), + "--jar ", YarnSparkHadoopUtil.escapeForShell(args.userJar), userArgsToString(args), "--executor-memory", args.executorMemory.toString, "--executor-cores", args.executorCores.toString, diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 71a9e42846b2b..312d82a649792 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -68,10 +68,10 @@ trait ExecutorRunnableUtil extends Logging { // authentication settings. sparkConf.getAll. filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } sparkConf.getAkkaConf. - foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index e98308cdbd74e..10aef5eb2486f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -148,4 +148,29 @@ object YarnSparkHadoopUtil { } } + /** + * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands + * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The + * argument is enclosed in single quotes and some key characters are escaped. + * + * @param arg A single argument. + * @return Argument quoted for execution via Yarn's generated shell script. + */ + def escapeForShell(arg: String): String = { + if (arg != null) { + val escaped = new StringBuilder("'") + for (i <- 0 to arg.length() - 1) { + arg.charAt(i) match { + case '$' => escaped.append("\\$") + case '"' => escaped.append("\\\"") + case '\'' => escaped.append("'\\''") + case c => escaped.append(c) + } + } + escaped.append("'").toString() + } else { + arg + } + } + } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala new file mode 100644 index 0000000000000..7650bd4396c12 --- /dev/null +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, IOException} + +import com.google.common.io.{ByteStreams, Files} +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.Logging + +class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { + + val hasBash = + try { + val exitCode = Runtime.getRuntime().exec(Array("bash", "--version")).waitFor() + exitCode == 0 + } catch { + case e: IOException => + false + } + + if (!hasBash) { + logWarning("Cannot execute bash, skipping bash tests.") + } + + def bashTest(name: String)(fn: => Unit) = + if (hasBash) test(name)(fn) else ignore(name)(fn) + + bashTest("shell script escaping") { + val scriptFile = File.createTempFile("script.", ".sh") + val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") + try { + val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") + Files.write(("bash -c \"echo " + argLine + "\"").getBytes(), scriptFile) + scriptFile.setExecutable(true) + + val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath())) + val out = new String(ByteStreams.toByteArray(proc.getInputStream())).trim() + val err = new String(ByteStreams.toByteArray(proc.getErrorStream())) + val exitCode = proc.waitFor() + exitCode should be (0) + out should be (args.mkString(" ")) + } finally { + scriptFile.delete() + } + } + +} From 98778fffdb4e11593149eb7770071a0728653f19 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 18 Aug 2014 14:40:05 -0700 Subject: [PATCH 184/231] [mllib] DecisionTree: treeAggregate + Python example bug fix Small DecisionTree updates: * Changed main DecisionTree aggregate to treeAggregate. * Fixed bug in python example decision_tree_runner.py with missing argument (since categoricalFeaturesInfo is no longer an optional argument for trainClassifier). * Fixed same bug in python doc tests, and added tree.py to doc tests. CC: mengxr Author: Joseph K. Bradley Closes #2015 from jkbradley/dt-opt2 and squashes the following commits: b5114fa [Joseph K. Bradley] Fixed python tree.py doc test (extra newline) 8e4665d [Joseph K. Bradley] Added tree.py to python doc tests. Fixed bug from missing categoricalFeaturesInfo argument. b7b2922 [Joseph K. Bradley] Fixed bug in python example decision_tree_runner.py with missing argument. Changed main DecisionTree aggregate to treeAggregate. 85bbc1f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 66d076f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata. Small doc updates. 3726d20 [Joseph K. Bradley] Small code improvements based on code review. ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow. db0d773 [Joseph K. Bradley] scala style fix 6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code 931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level. Needed to update treePointToNodeIndex with groupShift. f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint 6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory. 2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala. 356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs. d036089 [Joseph K. Bradley] Print timing info to logDebug. e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private 8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serialiable a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt 0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree 3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing 511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing a95bc22 [Joseph K. Bradley] timing for DecisionTree internals (cherry picked from commit 115eeb30dd9c9dd10685a71f2c23ca23794d3142) Signed-off-by: Xiangrui Meng --- .../src/main/python/mllib/decision_tree_runner.py | 4 +++- .../org/apache/spark/mllib/tree/DecisionTree.scala | 3 ++- python/pyspark/mllib/tree.py | 14 ++++++++------ python/run-tests | 1 + 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index 8efadb5223f56..db96a7cb3730f 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -124,7 +124,9 @@ def usage(): (reindexedData, origToNewLabels) = reindexClassLabels(points) # Train a classifier. - model = DecisionTree.trainClassifier(reindexedData, numClasses=2) + categoricalFeaturesInfo={} # no categorical features + model = DecisionTree.trainClassifier(reindexedData, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. print "Trained DecisionTree for classification:" print " Model numNodes: %d\n" % model.numNodes() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6b9a8f72c244e..5cdd258f6c20b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -826,7 +827,7 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregates. timer.start("aggregation") val binAggregates = { - input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) + input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index e1a4671709b7d..e9d778df5a24b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -88,7 +88,8 @@ class DecisionTree(object): It will probably be modified for Spark v1.2. Example usage: - >>> from numpy import array, ndarray + >>> from numpy import array + >>> import sys >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> from pyspark.mllib.linalg import SparseVector @@ -99,15 +100,15 @@ class DecisionTree(object): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> - >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) - >>> print(model) + >>> categoricalFeaturesInfo = {} # no categorical features + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2, + ... categoricalFeaturesInfo=categoricalFeaturesInfo) + >>> sys.stdout.write(model) DecisionTreeModel classifier If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 True >>> model.predict(array([0.0])) == 0 @@ -119,7 +120,8 @@ class DecisionTree(object): ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> - >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), + ... categoricalFeaturesInfo=categoricalFeaturesInfo) >>> model.predict(array([0.0, 1.0])) == 1 True >>> model.predict(array([0.0, 0.0])) == 0 diff --git a/python/run-tests b/python/run-tests index 1218edcbd7e08..a6271e0cf5fa9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -79,6 +79,7 @@ run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then From e3f89e971b117e11d15e4b9b47e63da55f4e488b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 18 Aug 2014 18:01:39 -0700 Subject: [PATCH 185/231] [SPARK-2850] [SPARK-2626] [mllib] MLlib stats examples + small fixes Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API) Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey Added sc.stop() to all examples. CorrelationSuite.scala * Added 1 test for RDDs with only 1 value RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. Python SparseVector (pyspark/mllib/linalg.py) * Added toDense() function python/run-tests script * Added stat.py (doc test) CC: mengxr dorx Main changes were examples to show usage across APIs. Author: Joseph K. Bradley Closes #1878 from jkbradley/mllib-stats-api-check and squashes the following commits: ea5c047 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check dafebe2 [Joseph K. Bradley] Bug fixes for examples SampledRDDs.scala and sampled_rdds.py: Check for division by 0 and for missing key in maps. 8d1e555 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 60c72d9 [Joseph K. Bradley] Fixed stat.py doc test to work for Python versions printing nan or NaN. b20d90a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 4e5d15e [Joseph K. Bradley] Changed pyspark/mllib/stat.py doc tests to use NaN instead of nan. 32173b7 [Joseph K. Bradley] Stats examples update. c8c20dc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check cf70b07 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 0b7cec3 [Joseph K. Bradley] Small updates based on code review. Renamed statistical_summary.py to correlations.py ab48f6e [Joseph K. Bradley] RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. 65e4ebc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 8195c78 [Joseph K. Bradley] Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey 064985b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check ee918e9 [Joseph K. Bradley] Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API) (cherry picked from commit c8b16ca0d86cc60fb960eebf0cb383f159a88b03) Signed-off-by: Xiangrui Meng --- examples/src/main/python/als.py | 2 + .../src/main/python/cassandra_inputformat.py | 2 + .../src/main/python/cassandra_outputformat.py | 2 + examples/src/main/python/hbase_inputformat.py | 2 + .../src/main/python/hbase_outputformat.py | 2 + examples/src/main/python/kmeans.py | 2 + .../src/main/python/logistic_regression.py | 2 + .../src/main/python/mllib/correlations.py | 60 +++++++++ .../main/python/mllib/decision_tree_runner.py | 5 + examples/src/main/python/mllib/kmeans.py | 1 + .../main/python/mllib/logistic_regression.py | 1 + .../python/mllib/random_rdd_generation.py | 55 ++++++++ .../src/main/python/mllib/sampled_rdds.py | 86 ++++++++++++ examples/src/main/python/pagerank.py | 2 + examples/src/main/python/pi.py | 2 + examples/src/main/python/sort.py | 2 + .../src/main/python/transitive_closure.py | 2 + examples/src/main/python/wordcount.py | 2 + .../spark/examples/mllib/Correlations.scala | 92 +++++++++++++ .../mllib/MultivariateSummarizer.scala | 98 ++++++++++++++ .../examples/mllib/RandomRDDGeneration.scala | 60 +++++++++ .../spark/examples/mllib/SampledRDDs.scala | 126 ++++++++++++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 14 +- .../stat/MultivariateOnlineSummarizer.scala | 8 +- .../spark/mllib/stat/CorrelationSuite.scala | 15 ++- .../MultivariateOnlineSummarizerSuite.scala | 6 +- python/pyspark/mllib/linalg.py | 10 ++ python/pyspark/mllib/stat.py | 22 +-- python/run-tests | 1 + 29 files changed, 664 insertions(+), 20 deletions(-) create mode 100755 examples/src/main/python/mllib/correlations.py create mode 100755 examples/src/main/python/mllib/random_rdd_generation.py create mode 100755 examples/src/main/python/mllib/sampled_rdds.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index c862650b0aa1d..5b1fa4d997eeb 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -97,3 +97,5 @@ def update(i, vec, mat, ratings): error = rmse(R, ms, us) print "Iteration %d:" % i print "\nRMSE: %5.4f\n" % error + + sc.stop() diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index 39fa6b0d22ef5..e4a897f61e39d 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -77,3 +77,5 @@ output = cass_rdd.collect() for (k, v) in output: print (k, v) + + sc.stop() diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index 1dfbf98604425..836c35b5c6794 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -81,3 +81,5 @@ conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", valueConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLValueConverter") + + sc.stop() diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index c9fa8e171c2a1..befacee0dea56 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -71,3 +71,5 @@ output = hbase_rdd.collect() for (k, v) in output: print (k, v) + + sc.stop() diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index 5e11548fd13f7..49bbc5aebdb0b 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -63,3 +63,5 @@ conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") + + sc.stop() diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 036bdf4c4f999..86ef6f32c84e8 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -77,3 +77,5 @@ def closestPoint(p, centers): kPoints[x] = y print "Final centers: " + str(kPoints) + + sc.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 8456b272f9c05..3aa56b0528168 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -80,3 +80,5 @@ def add(x, y): w -= points.map(lambda m: gradient(m, w)).reduce(add) print "Final w: " + str(w) + + sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py new file mode 100755 index 0000000000000..6b16a56e44af7 --- /dev/null +++ b/examples/src/main/python/mllib/correlations.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Correlations using MLlib. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.mllib.util import MLUtils + + +if __name__ == "__main__": + if len(sys.argv) not in [1,2]: + print >> sys.stderr, "Usage: correlations ()" + exit(-1) + sc = SparkContext(appName="PythonCorrelations") + if len(sys.argv) == 2: + filepath = sys.argv[1] + else: + filepath = 'data/mllib/sample_linear_regression_data.txt' + corrType = 'pearson' + + points = MLUtils.loadLibSVMFile(sc, filepath)\ + .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray())) + + print + print 'Summary of data file: ' + filepath + print '%d data points' % points.count() + + # Statistics (correlations) + print + print 'Correlation (%s) between label and each feature' % corrType + print 'Feature\tCorrelation' + numFeatures = points.take(1)[0].features.size + labelRDD = points.map(lambda lp: lp.label) + for i in range(numFeatures): + featureRDD = points.map(lambda lp: lp.features[i]) + corr = Statistics.corr(labelRDD, featureRDD, corrType) + print '%d\t%g' % (i, corr) + print + + sc.stop() diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index db96a7cb3730f..6e4a4a0cb6be0 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -17,6 +17,8 @@ """ Decision tree classification and regression using MLlib. + +This example requires NumPy (http://www.numpy.org/). """ import numpy, os, sys @@ -117,6 +119,7 @@ def usage(): if len(sys.argv) == 2: dataPath = sys.argv[1] if not os.path.isfile(dataPath): + sc.stop() usage() points = MLUtils.loadLibSVMFile(sc, dataPath) @@ -133,3 +136,5 @@ def usage(): print " Model depth: %d\n" % model.depth() print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) print model + + sc.stop() diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index b308132c9aeeb..2eeb1abeeb12b 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -42,3 +42,4 @@ def parseVector(line): k = int(sys.argv[2]) model = KMeans.train(data, k) print "Final centers: " + str(model.clusterCenters) + sc.stop() diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 9d547ff77c984..8cae27fc4a52d 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -50,3 +50,4 @@ def parsePoint(line): model = LogisticRegressionWithSGD.train(points, iterations) print "Final weights: " + str(model.weights) print "Final intercept: " + str(model.intercept) + sc.stop() diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py new file mode 100755 index 0000000000000..b388d8d83fb86 --- /dev/null +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Randomly generated RDDs. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.random import RandomRDDs + + +if __name__ == "__main__": + if len(sys.argv) not in [1, 2]: + print >> sys.stderr, "Usage: random_rdd_generation" + exit(-1) + + sc = SparkContext(appName="PythonRandomRDDGeneration") + + numExamples = 10000 # number of examples to generate + fraction = 0.1 # fraction of data to sample + + # Example: RandomRDDs.normalRDD + normalRDD = RandomRDDs.normalRDD(sc, numExamples) + print 'Generated RDD of %d examples sampled from the standard normal distribution'\ + % normalRDD.count() + print ' First 5 samples:' + for sample in normalRDD.take(5): + print ' ' + str(sample) + print + + # Example: RandomRDDs.normalVectorRDD + normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() + print ' First 5 samples:' + for sample in normalVectorRDD.take(5): + print ' ' + str(sample) + print + + sc.stop() diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py new file mode 100755 index 0000000000000..ec64a5978c672 --- /dev/null +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Randomly sampled RDDs. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.util import MLUtils + + +if __name__ == "__main__": + if len(sys.argv) not in [1, 2]: + print >> sys.stderr, "Usage: sampled_rdds " + exit(-1) + if len(sys.argv) == 2: + datapath = sys.argv[1] + else: + datapath = 'data/mllib/sample_binary_classification_data.txt' + + sc = SparkContext(appName="PythonSampledRDDs") + + fraction = 0.1 # fraction of data to sample + + examples = MLUtils.loadLibSVMFile(sc, datapath) + numExamples = examples.count() + if numExamples == 0: + print >> sys.stderr, "Error: Data file had no samples to load." + exit(1) + print 'Loaded data with %d examples from file: %s' % (numExamples, datapath) + + # Example: RDD.sample() and RDD.takeSample() + expectedSampleSize = int(numExamples * fraction) + print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ + % (fraction, expectedSampleSize) + sampledRDD = examples.sample(withReplacement = True, fraction = fraction) + print ' RDD.sample(): sample has %d examples' % sampledRDD.count() + sampledArray = examples.takeSample(withReplacement = True, num = expectedSampleSize) + print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) + + print + + # Example: RDD.sampleByKey() + keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features)) + print ' Keyed data using label (Int) as key ==> Orig' + # Count examples per label in original data. + keyCountsA = keyedRDD.countByKey() + + # Subsample, and count examples per label in sampled data. + fractions = {} + for k in keyCountsA.keys(): + fractions[k] = fraction + sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = True, fractions = fractions) + keyCountsB = sampledByKeyRDD.countByKey() + sizeB = sum(keyCountsB.values()) + print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ + % sizeB + + # Compare samples + print ' \tFractions of examples with key' + print 'Key\tOrig\tSample' + for k in sorted(keyCountsA.keys()): + fracA = keyCountsA[k] / float(numExamples) + if sizeB != 0: + fracB = keyCountsB.get(k, 0) / float(sizeB) + else: + fracB = 0 + print '%d\t%g\t%g' % (k, fracA, fracB) + + sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0b96343158d44..b539c4128cdcc 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -68,3 +68,5 @@ def parseNeighbors(urls): # Collects all URL ranks and dump them to console. for (link, rank) in ranks.collect(): print "%s has rank: %s." % (link, rank) + + sc.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index 21d94a2cd4b64..fc37459dc74aa 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -37,3 +37,5 @@ def f(_): count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) print "Pi is roughly %f" % (4.0 * count / n) + + sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 41d00c1b79133..bb686f17518a0 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -34,3 +34,5 @@ output = sortedCount.collect() for (num, unitcount) in output: print num + + sc.stop() diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 8698369b13d84..bf331b542c438 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -64,3 +64,5 @@ def generateGraph(): break print "TC has %i edges" % tc.count() + + sc.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index dcc095fdd0ed9..ae6cd13b83d92 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -33,3 +33,5 @@ output = counts.collect() for (word, count) in output: print "%s: %i" % (word, count) + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala new file mode 100644 index 0000000000000..d6b2fe430e5a4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.{SparkConf, SparkContext} + + +/** + * An example app for summarizing multivariate data from a file. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.Correlations + * }}} + * By default, this loads a synthetic dataset from `data/mllib/sample_linear_regression_data.txt`. + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object Correlations { + + case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + + def main(args: Array[String]) { + + val defaultParams = Params() + + val parser = new OptionParser[Params]("Correlations") { + head("Correlations: an example app for computing correlations") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.Correlations \ + | examples/target/scala-*/spark-examples-*.jar \ + | --input data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"Correlations with $params") + val sc = new SparkContext(conf) + + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() + + println(s"Summary of data file: ${params.input}") + println(s"${examples.count()} data points") + + // Calculate label -- feature correlations + val labelRDD = examples.map(_.label) + val numFeatures = examples.take(1)(0).features.size + val corrType = "pearson" + println() + println(s"Correlation ($corrType) between label and each feature") + println(s"Feature\tCorrelation") + var feature = 0 + while (feature < numFeatures) { + val featureRDD = examples.map(_.features(feature)) + val corr = Statistics.corr(labelRDD, featureRDD) + println(s"$feature\t$corr") + feature += 1 + } + println() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala new file mode 100644 index 0000000000000..4532512c01f84 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.{SparkConf, SparkContext} + + +/** + * An example app for summarizing multivariate data from a file. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.MultivariateSummarizer + * }}} + * By default, this loads a synthetic dataset from `data/mllib/sample_linear_regression_data.txt`. + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object MultivariateSummarizer { + + case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + + def main(args: Array[String]) { + + val defaultParams = Params() + + val parser = new OptionParser[Params]("MultivariateSummarizer") { + head("MultivariateSummarizer: an example app for MultivariateOnlineSummarizer") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.MultivariateSummarizer \ + | examples/target/scala-*/spark-examples-*.jar \ + | --input data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"MultivariateSummarizer with $params") + val sc = new SparkContext(conf) + + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() + + println(s"Summary of data file: ${params.input}") + println(s"${examples.count()} data points") + + // Summarize labels + val labelSummary = examples.aggregate(new MultivariateOnlineSummarizer())( + (summary, lp) => summary.add(Vectors.dense(lp.label)), + (sum1, sum2) => sum1.merge(sum2)) + + // Summarize features + val featureSummary = examples.aggregate(new MultivariateOnlineSummarizer())( + (summary, lp) => summary.add(lp.features), + (sum1, sum2) => sum1.merge(sum2)) + + println() + println(s"Summary statistics") + println(s"\tLabel\tFeatures") + println(s"mean\t${labelSummary.mean(0)}\t${featureSummary.mean.toArray.mkString("\t")}") + println(s"var\t${labelSummary.variance(0)}\t${featureSummary.variance.toArray.mkString("\t")}") + println( + s"nnz\t${labelSummary.numNonzeros(0)}\t${featureSummary.numNonzeros.toArray.mkString("\t")}") + println(s"max\t${labelSummary.max(0)}\t${featureSummary.max.toArray.mkString("\t")}") + println(s"min\t${labelSummary.min(0)}\t${featureSummary.min.toArray.mkString("\t")}") + println() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala new file mode 100644 index 0000000000000..924b586e3af99 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.random.RandomRDDs +import org.apache.spark.rdd.RDD + +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example app for randomly generated RDDs. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.RandomRDDGeneration + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object RandomRDDGeneration { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName(s"RandomRDDGeneration") + val sc = new SparkContext(conf) + + val numExamples = 10000 // number of examples to generate + val fraction = 0.1 // fraction of data to sample + + // Example: RandomRDDs.normalRDD + val normalRDD: RDD[Double] = RandomRDDs.normalRDD(sc, numExamples) + println(s"Generated RDD of ${normalRDD.count()}" + + " examples sampled from the standard normal distribution") + println(" First 5 samples:") + normalRDD.take(5).foreach( x => println(s" $x") ) + + // Example: RandomRDDs.normalVectorRDD + val normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + println(s"Generated RDD of ${normalVectorRDD.count()} examples of length-2 vectors.") + println(" First 5 samples:") + normalVectorRDD.take(5).foreach( x => println(s" $x") ) + + println() + + sc.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala new file mode 100644 index 0000000000000..f01b8266e3fe3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.util.MLUtils +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ + +/** + * An example app for randomly generated and sampled RDDs. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.SampledRDDs + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object SampledRDDs { + + case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("SampledRDDs") { + head("SampledRDDs: an example app for randomly generated and sampled RDDs.") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.SampledRDDs \ + | examples/target/scala-*/spark-examples-*.jar + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"SampledRDDs with $params") + val sc = new SparkContext(conf) + + val fraction = 0.1 // fraction of data to sample + + val examples = MLUtils.loadLibSVMFile(sc, params.input) + val numExamples = examples.count() + if (numExamples == 0) { + throw new RuntimeException("Error: Data file had no samples to load.") + } + println(s"Loaded data with $numExamples examples from file: ${params.input}") + + // Example: RDD.sample() and RDD.takeSample() + val expectedSampleSize = (numExamples * fraction).toInt + println(s"Sampling RDD using fraction $fraction. Expected sample size = $expectedSampleSize.") + val sampledRDD = examples.sample(withReplacement = true, fraction = fraction) + println(s" RDD.sample(): sample has ${sampledRDD.count()} examples") + val sampledArray = examples.takeSample(withReplacement = true, num = expectedSampleSize) + println(s" RDD.takeSample(): sample has ${sampledArray.size} examples") + + println() + + // Example: RDD.sampleByKey() and RDD.sampleByKeyExact() + val keyedRDD = examples.map { lp => (lp.label.toInt, lp.features) } + println(s" Keyed data using label (Int) as key ==> Orig") + // Count examples per label in original data. + val keyCounts = keyedRDD.countByKey() + + // Subsample, and count examples per label in sampled data. (approximate) + val fractions = keyCounts.keys.map((_, fraction)).toMap + val sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = true, fractions = fractions) + val keyCountsB = sampledByKeyRDD.countByKey() + val sizeB = keyCountsB.values.sum + println(s" Sampled $sizeB examples using approximate stratified sampling (by label)." + + " ==> Approx Sample") + + // Subsample, and count examples per label in sampled data. (approximate) + val sampledByKeyRDDExact = + keyedRDD.sampleByKeyExact(withReplacement = true, fractions = fractions) + val keyCountsBExact = sampledByKeyRDDExact.countByKey() + val sizeBExact = keyCountsBExact.values.sum + println(s" Sampled $sizeBExact examples using exact stratified sampling (by label)." + + " ==> Exact Sample") + + // Compare samples + println(s" \tFractions of examples with key") + println(s"Key\tOrig\tApprox Sample\tExact Sample") + keyCounts.keys.toSeq.sorted.foreach { key => + val origFrac = keyCounts(key) / numExamples.toDouble + val approxFrac = if (sizeB != 0) { + keyCountsB.getOrElse(key, 0L) / sizeB.toDouble + } else { + 0 + } + val exactFrac = if (sizeBExact != 0) { + keyCountsBExact.getOrElse(key, 0L) / sizeBExact.toDouble + } else { + 0 + } + println(s"$key\t$origFrac\t$approxFrac\t$exactFrac") + } + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e76bc9fefff01..2e414a73be8e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -53,8 +53,14 @@ class RowMatrix( /** Gets or computes the number of columns. */ override def numCols(): Long = { if (nCols <= 0) { - // Calling `first` will throw an exception if `rows` is empty. - nCols = rows.first().size + try { + // Calling `first` will throw an exception if `rows` is empty. + nCols = rows.first().size + } catch { + case err: UnsupportedOperationException => + sys.error("Cannot determine the number of cols because it is not specified in the " + + "constructor and the rows RDD is empty.") + } } nCols } @@ -293,6 +299,10 @@ class RowMatrix( (s1._1 + s2._1, s1._2 += s2._2) ) + if (m <= 1) { + sys.error(s"RowMatrix.computeCovariance called on matrix with only $m rows." + + " Cannot compute the covariance of a RowMatrix with <= 1 row.") + } updateNumRows(m) mean :/= m.toDouble diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 5105b5c37aaaa..7d845c44365dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -55,8 +55,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ def add(sample: Vector): this.type = { if (n == 0) { - require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.") - n = sample.toBreeze.length + require(sample.size > 0, s"Vector should have dimension larger than zero.") + n = sample.size currMean = BDV.zeros[Double](n) currM2n = BDV.zeros[Double](n) @@ -65,8 +65,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = BDV.fill(n)(Double.MaxValue) } - require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.toBreeze.length}.") + require(n == sample.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${sample.size}.") sample.toBreeze.activeIterator.foreach { case (_, 0.0) => // Skip explicit zero elements. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index a3f76f77a5dcc..34548c86ebc14 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -39,6 +39,17 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { Vectors.dense(9.0, 0.0, 0.0, 1.0) ) + test("corr(x, y) pearson, 1 value in data") { + val x = sc.parallelize(Array(1.0)) + val y = sc.parallelize(Array(4.0)) + intercept[RuntimeException] { + Statistics.corr(x, y, "pearson") + } + intercept[RuntimeException] { + Statistics.corr(x, y, "spearman") + } + } + test("corr(x, y) default, pearson") { val x = sc.parallelize(xData) val y = sc.parallelize(yData) @@ -58,7 +69,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // RDD of zero variance val z = sc.parallelize(zeros) - assert(Statistics.corr(x, z).isNaN()) + assert(Statistics.corr(x, z).isNaN) } test("corr(x, y) spearman") { @@ -78,7 +89,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // RDD of zero variance => zero variance in ranks val z = sc.parallelize(zeros) - assert(Statistics.corr(x, z, "spearman").isNaN()) + assert(Statistics.corr(x, z, "spearman").isNaN) } test("corr(X) default, pearson") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index db13f142df517..1e9415249104b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -139,7 +139,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") assert(summarizer.variance ~== - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") assert(summarizer.count === 6) } @@ -167,7 +168,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") assert(summarizer.variance ~== - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") assert(summarizer.count === 6) } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 9a239abfbbeb1..f485a69db1fa2 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -23,6 +23,7 @@ SciPy is available in their environment. """ +import numpy from numpy import array, array_equal, ndarray, float64, int32 @@ -160,6 +161,15 @@ def squared_distance(self, other): j += 1 return result + def toArray(self): + """ + Returns a copy of this SparseVector as a 1-dimensional NumPy array. + """ + arr = numpy.zeros(self.size) + for i in xrange(self.indices.size): + arr[self.indices[i]] = self.values[i] + return arr + def __str__(self): inds = "[" + ",".join([str(i) for i in self.indices]) + "]" vals = "[" + ",".join([str(v) for v in self.values]) + "]" diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index a73abc5ff90df..feef0d16cd644 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -118,16 +118,18 @@ def corr(x, y=None, method=None): >>> from linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) - >>> Statistics.corr(rdd) - array([[ 1. , 0.05564149, nan, 0.40047142], - [ 0.05564149, 1. , nan, 0.91359586], - [ nan, nan, 1. , nan], - [ 0.40047142, 0.91359586, nan, 1. ]]) - >>> Statistics.corr(rdd, method="spearman") - array([[ 1. , 0.10540926, nan, 0.4 ], - [ 0.10540926, 1. , nan, 0.9486833 ], - [ nan, nan, 1. , nan], - [ 0.4 , 0.9486833 , nan, 1. ]]) + >>> pearsonCorr = Statistics.corr(rdd) + >>> print str(pearsonCorr).replace('nan', 'NaN') + [[ 1. 0.05564149 NaN 0.40047142] + [ 0.05564149 1. NaN 0.91359586] + [ NaN NaN 1. NaN] + [ 0.40047142 0.91359586 NaN 1. ]] + >>> spearmanCorr = Statistics.corr(rdd, method="spearman") + >>> print str(spearmanCorr).replace('nan', 'NaN') + [[ 1. 0.10540926 NaN 0.4 ] + [ 0.10540926 1. NaN 0.9486833 ] + [ NaN NaN 1. NaN] + [ 0.4 0.9486833 NaN 1. ]] >>> try: ... Statistics.corr(rdd, "spearman") ... print "Method name as second argument without 'method=' shouldn't be allowed." diff --git a/python/run-tests b/python/run-tests index a6271e0cf5fa9..b506559a5e810 100755 --- a/python/run-tests +++ b/python/run-tests @@ -78,6 +78,7 @@ run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" +run_test "pyspark/mllib/stat.py" run_test "pyspark/mllib/tests.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" From 7d069bf0c57b75b53b449fcc51cf7fd616f8686d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 18 Aug 2014 18:20:54 -0700 Subject: [PATCH 186/231] [SPARK-3108][MLLIB] add predictOnValues to StreamingLR and fix predictOn It is useful in streaming to allow users to carry extra data with the prediction, for monitoring the prediction error for example. freeman-lab Author: Xiangrui Meng Closes #2023 from mengxr/predict-on-values and squashes the following commits: cac47b8 [Xiangrui Meng] add classtag 2821b3b [Xiangrui Meng] use mapValues 0925efa [Xiangrui Meng] add predictOnValues to StreamingLR and fix predictOn (cherry picked from commit 217b5e915e2f21f047dfc4be680cd20d58baf9f8) Signed-off-by: Xiangrui Meng --- .../mllib/StreamingLinearRegression.scala | 4 +-- .../regression/StreamingLinearAlgorithm.scala | 31 +++++++++++++++---- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 0e992fa9967bb..c5bd5b0b178d9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -59,10 +59,10 @@ object StreamingLinearRegression { val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) + .setInitialWeights(Vectors.zeros(args(3).toInt)) model.trainOn(trainingData) - model.predictOn(testData).print() + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index b8b0b42611775..8db0442a7a569 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -17,8 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import scala.reflect.ClassTag + import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream /** @@ -92,15 +96,30 @@ abstract class StreamingLinearAlgorithm[ /** * Use the model to make predictions on batches of data from a DStream * - * @param data DStream containing labeled data + * @param data DStream containing feature vectors * @return DStream containing predictions */ - def predictOn(data: DStream[LabeledPoint]): DStream[Double] = { + def predictOn(data: DStream[Vector]): DStream[Double] = { if (Option(model.weights) == None) { - logError("Initial weights must be set before starting prediction") - throw new IllegalArgumentException + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) } - data.map(x => model.predict(x.features)) + data.map(model.predict) } + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * @param data DStream containing feature vectors + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { + if (Option(model.weights) == None) { + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) + } + data.mapValues(model.predict) + } } From 3a03259a0421b08269a3b23cab2bdb4f9108f5c5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 18 Aug 2014 20:42:19 -0700 Subject: [PATCH 187/231] [SPARK-3114] [PySpark] Fix Python UDFs in Spark SQL. This fixes SPARK-3114, an issue where we inadvertently broke Python UDFs in Spark SQL. This PR modifiers the test runner script to always run the PySpark SQL tests, irrespective of whether SparkSQL itself has been modified. It also includes Davies' fix for the bug. Closes #2026. Author: Josh Rosen Author: Davies Liu Closes #2027 from JoshRosen/pyspark-sql-fix and squashes the following commits: 9af2708 [Davies Liu] bugfix: disable compression of command 0d8d3a4 [Josh Rosen] Always run Python Spark SQL tests. (cherry picked from commit 1f1819b20f887b487557c31e54b8bcd95b582dc6) Signed-off-by: Josh Rosen --- dev/run-tests | 17 +++++++++++++---- python/pyspark/rdd.py | 2 +- python/pyspark/worker.py | 2 +- python/run-tests | 4 +--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 0e24515d1376c..132f696d6447a 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -58,7 +58,7 @@ if [ -n "$AMPLAB_JENKINS" ]; then diffs=`git diff --name-only master | grep "^sql/"` if [ -n "$diffs" ]; then echo "Detected changes in SQL. Will run Hive test suite." - export _RUN_SQL_TESTS=true # exported for PySpark tests + _RUN_SQL_TESTS=true fi fi @@ -89,13 +89,22 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" +# Build Spark; we always build with Hive because the PySpark SparkSQL tests need it. +# echo "q" is needed because sbt on encountering a build file with failure +# (either resolution or compilation) prompts the user for input either q, r, +# etc to quit or retry. This echo is there to make it not block. +BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver " +echo -e "q\n" | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly | \ + grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + +# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled: if [ -n "$_RUN_SQL_TESTS" ]; then SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" fi -# echo "q" is needed because sbt on encountering a build file with failure -# (either resolution or compilation) prompts the user for input either q, r, +# echo "q" is needed because sbt on encountering a build file with failure +# (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. -echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \ +echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS test | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" echo "" diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c708b69cc1e31..86cd89b245aea 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1812,7 +1812,7 @@ def _jrdd(self): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) - ser = CompressedSerializer(CloudPickleSerializer()) + ser = CloudPickleSerializer() pickled_command = ser.dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 77a9c4a0e0677..6805063e06798 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -72,7 +72,7 @@ def main(infile, outfile): value = ser._read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, value) - command = ser._read_with_length(infile) + command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) diff --git a/python/run-tests b/python/run-tests index b506559a5e810..7b1ee3e1cddba 100755 --- a/python/run-tests +++ b/python/run-tests @@ -59,9 +59,7 @@ $PYSPARK_PYTHON --version run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "pyspark/conf.py" -if [ -n "$_RUN_SQL_TESTS" ]; then - run_test "pyspark/sql.py" -fi +run_test "pyspark/sql.py" # These tests are included in the module-level docs, and so must # be handled on a higher level rather than within the python file. export PYSPARK_DOC_TEST=1 From b6d8e66517f264e8576c785624fee9d1ff26900b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 18 Aug 2014 20:51:41 -0700 Subject: [PATCH 188/231] [SPARK-3116] Remove the excessive lockings in TorrentBroadcast Author: Reynold Xin Closes #2028 from rxin/torrentBroadcast and squashes the following commits: 92c62a5 [Reynold Xin] Revert the MEMORY_AND_DISK_SER changes. 03a5221 [Reynold Xin] [SPARK-3116] Remove the excessive lockings in TorrentBroadcast (cherry picked from commit 82577339dd58b5811eab5d10667775e61e37ff51) Signed-off-by: Reynold Xin --- .../spark/broadcast/TorrentBroadcast.scala | 66 ++++++++----------- 1 file changed, 27 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index fe73456ef8fad..d8be649f96e5f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,8 +17,7 @@ package org.apache.spark.broadcast -import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream, - ObjectInputStream, ObjectOutputStream, OutputStream} +import java.io._ import scala.reflect.ClassTag import scala.util.Random @@ -53,10 +52,8 @@ private[spark] class TorrentBroadcast[T: ClassTag]( private val broadcastId = BroadcastBlockId(id) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - } + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) @transient private var arrayOfBlocks: Array[TorrentBlock] = null @transient private var totalBlocks = -1 @@ -91,18 +88,14 @@ private[spark] class TorrentBroadcast[T: ClassTag]( // Store meta-info val metaId = BroadcastBlockId(id, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } + SparkEnv.get.blockManager.putSingle( + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // Store individual pieces for (i <- 0 until totalBlocks) { val pieceId = BroadcastBlockId(id, "piece" + i) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } + SparkEnv.get.blockManager.putSingle( + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } @@ -165,21 +158,20 @@ private[spark] class TorrentBroadcast[T: ClassTag]( val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(metaId) match { - case Some(x) => - val tInfo = x.asInstanceOf[TorrentInfo] - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - arrayOfBlocks = new Array[TorrentBlock](totalBlocks) - hasBlocks = 0 - - case None => - Thread.sleep(500) - } + SparkEnv.get.blockManager.getSingle(metaId) match { + case Some(x) => + val tInfo = x.asInstanceOf[TorrentInfo] + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + hasBlocks = 0 + + case None => + Thread.sleep(500) } attemptId -= 1 } + if (totalBlocks == -1) { return false } @@ -192,17 +184,15 @@ private[spark] class TorrentBroadcast[T: ClassTag]( val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { val pieceId = BroadcastBlockId(id, "piece" + pid) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => - arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] - hasBlocks += 1 - SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + hasBlocks += 1 + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) - } + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) } } @@ -291,9 +281,7 @@ private[broadcast] object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. */ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { - synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) - } + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } From f3b0f34b4403bceb9b98740084b9ddba4314d71a Mon Sep 17 00:00:00 2001 From: Matt Forbes Date: Mon, 18 Aug 2014 21:43:32 -0700 Subject: [PATCH 189/231] Fix typo in decision tree docs Candidate splits were inconsistent with the example. Author: Matt Forbes Closes #1837 from emef/tree-doc and squashes the following commits: 3be14a1 [Matt Forbes] Fix typo in decision tree docs (cherry picked from commit cd0720ca77894d481fb73a8b5bb517013843cb1e) Signed-off-by: Xiangrui Meng --- docs/mllib-decision-tree.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 9cbd880897578..c01a92a9a1b26 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -84,8 +84,8 @@ Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for details). For example, for a binary classification problem with one categorical feature with three categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical -features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B -and A , B \| C where \| denotes the split. A similar heuristic is used for multiclass classification +features are ordered as A followed by C followed B or A, C, B. The two split candidates are A \| C, B +and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value is used for ordering. From 1418893da557892b86fc47f1e41e91880d4f8eda Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 19 Aug 2014 09:40:31 -0500 Subject: [PATCH 190/231] [SPARK-3072] YARN - Exit when reach max number failed executors In some cases on hadoop 2.x the spark application master doesn't properly exit and hangs around for 10 minutes after its really done. We should make sure it exits properly and stops the driver. Author: Thomas Graves Closes #2022 from tgravescs/SPARK-3072 and squashes the following commits: 665701d [Thomas Graves] Exit when reach max number failed executors (cherry picked from commit 7eb9cbc273d758522e787fcb2ef68ef65911475f) Signed-off-by: Thomas Graves --- .../spark/deploy/yarn/ApplicationMaster.scala | 33 ++++++++++++------- .../spark/deploy/yarn/ExecutorLauncher.scala | 5 +-- .../spark/deploy/yarn/ApplicationMaster.scala | 16 ++++++--- .../spark/deploy/yarn/ExecutorLauncher.scala | 5 +-- 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 62b5c3bc5f0f3..46a01f5a9a2cc 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -267,12 +267,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - // Exists the loop if the user thread exits. - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } + // Exits the loop if the user thread exits. + while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive + && !isFinished) { + checkNumExecutorsFailed() yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) @@ -303,11 +301,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { - while (userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } + while (userThread.isAlive && !isFinished) { + checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { logInfo("Allocating %d containers to make up for (potentially) lost containers". @@ -327,6 +322,22 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, t } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + logInfo("max number of executor failures reached") + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + // make sure to stop the user thread + val sparkContext = ApplicationMaster.sparkContextRef.get() + if (sparkContext != null) { + logInfo("Invoking sc stop from checkNumExecutorsFailed") + sparkContext.stop() + } else { + logError("sparkContext is null when should shutdown") + } + } + } + private def sendProgress() { logDebug("Sending progress") // Simulated with an allocate request with no nodes requested ... diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 184e2ad6c82cd..72c7143edcd71 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -249,7 +249,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && + !isFinished) { yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) checkNumExecutorsFailed() @@ -271,7 +272,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { - while (!driverClosed) { + while (!driverClosed && !isFinished) { checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 035356d390c80..9c2bcf17a8508 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -247,13 +247,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, yarnAllocator.allocateResources() // Exits the loop if the user thread exits. - var iters = 0 - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { + while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive + && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) - iters += 1 } } logInfo("All executors have launched.") @@ -271,8 +270,17 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private def checkNumExecutorsFailed() { if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + logInfo("max number of executor failures reached") finishApplicationMaster(FinalApplicationStatus.FAILED, "max number of executor failures reached") + // make sure to stop the user thread + val sparkContext = ApplicationMaster.sparkContextRef.get() + if (sparkContext != null) { + logInfo("Invoking sc stop from checkNumExecutorsFailed") + sparkContext.stop() + } else { + logError("sparkContext is null when should shutdown") + } } } @@ -289,7 +297,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { - while (userThread.isAlive) { + while (userThread.isAlive && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index fc7b8320d734d..a7585748b7f88 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -217,7 +217,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // Wait until all containers have launched yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && + !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() @@ -249,7 +250,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { - while (!driverClosed) { + while (!driverClosed && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") From 5d895ad5668823a52b143ac39d9ffa264fc2a7b2 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 19 Aug 2014 10:15:11 -0700 Subject: [PATCH 191/231] [SPARK-3089] Fix meaningless error message in ConnectionManager Author: Kousuke Saruta Closes #2000 from sarutak/SPARK-3089 and squashes the following commits: 02dfdea [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3089 e759ce7 [Kousuke Saruta] Improved error message when closing SendingConnection (cherry picked from commit cbfc26ba45f49559e64276c72e3054c6fe30ddd5) Signed-off-by: Josh Rosen --- .../main/scala/org/apache/spark/network/ConnectionManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index e77d762bdf221..b3e951ded6e77 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -467,7 +467,7 @@ private[spark] class ConnectionManager( val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) if (!sendingConnectionOpt.isDefined) { - logError("Corresponding SendingConnectionManagerId not found") + logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") return } From 04a32086212452d3488e12dd64ffa18af0243345 Mon Sep 17 00:00:00 2001 From: freeman Date: Tue, 19 Aug 2014 13:28:57 -0700 Subject: [PATCH 192/231] [SPARK-3128][MLLIB] Use streaming test suite for StreamingLR Refactored tests for streaming linear regression to use existing streaming test utilities. Summary of changes: - Made ``mllib`` depend on tests from ``streaming`` - Rewrote accuracy and convergence tests to use ``setupStreams`` and ``runStreams`` - Added new test for the accuracy of predictions generated by ``predictOnValue`` These tests should run faster, be easier to extend/maintain, and provide a reference for new tests. mengxr tdas Author: freeman Closes #2037 from freeman-lab/streamingLR-predict-tests and squashes the following commits: e851ca7 [freeman] Fixed long lines 50eb0bf [freeman] Refactored tests to use streaming test tools 32c43c2 [freeman] Added test for prediction (cherry picked from commit 31f0b071efd0b63eb9d6a6a131e5c4fa28237583) Signed-off-by: Tathagata Das --- mllib/pom.xml | 7 + .../StreamingLinearRegressionSuite.scala | 121 ++++++++++-------- .../spark/streaming/TestSuiteBase.scala | 4 +- 3 files changed, 77 insertions(+), 55 deletions(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index d5c2e5ab54caa..74f528f030987 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -91,6 +91,13 @@ junit-interface test + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 45e25eecf508e..28489410f8225 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.mllib.regression -import java.io.File -import java.nio.charset.Charset - import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.TestSuiteBase + +class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { -class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { + // use longer wait time to ensure job completion + override def maxWaitTimeMillis = 20000 // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { @@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data - test("streaming linear regression parameter accuracy") { + test("parameter accuracy") { - val testDir = Files.createTempDir() - val numBatches = 10 - val batchDuration = Milliseconds(1000) - val ssc = new StreamingContext(sc, batchDuration) - val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse) + // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.1) - .setNumIterations(50) + .setNumIterations(25) - model.trainOn(data) - - ssc.start() - - // write data to a file stream - for (i <- 0 until numBatches) { - val samples = LinearDataGenerator.generateLinearInput( - 0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) - val file = new File(testDir, i.toString) - Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) - Thread.sleep(batchDuration.milliseconds) + // generate sequence of simulated data + val numBatches = 10 + val input = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) } - ssc.stop(stopSparkContext=false) - - System.clearProperty("spark.driver.port") - Utils.deleteRecursively(testDir) + // apply model training to input stream + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) // check accuracy of final parameter estimates assertEqual(model.latestModel().intercept, 0.0, 0.1) @@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } // Test that parameter estimates improve when learning Y = 10*X1 on streaming data - test("streaming linear regression parameter convergence") { + test("parameter convergence") { - val testDir = Files.createTempDir() - val batchDuration = Milliseconds(2000) - val ssc = new StreamingContext(sc, batchDuration) - val numBatches = 5 - val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse) + // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0)) .setStepSize(0.1) - .setNumIterations(50) - - model.trainOn(data) - - ssc.start() + .setNumIterations(25) - // write data to a file stream - val history = new ArrayBuffer[Double](numBatches) - for (i <- 0 until numBatches) { - val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) - val file = new File(testDir, i.toString) - Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) - Thread.sleep(batchDuration.milliseconds) - // wait an extra few seconds to make sure the update finishes before new data arrive - Thread.sleep(4000) - history.append(math.abs(model.latestModel().weights(0) - 10.0)) + // generate sequence of simulated data + val numBatches = 10 + val input = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) } - ssc.stop(stopSparkContext=false) + // create buffer to store intermediate fits + val history = new ArrayBuffer[Double](numBatches) - System.clearProperty("spark.driver.port") - Utils.deleteRecursively(testDir) + // apply model training to input stream, storing the intermediate results + // (we add a count to ensure the result is a DStream) + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + // compute change in error val deltas = history.drop(1).zip(history.dropRight(1)) // check error stability (it always either shrinks, or increases with small tol) assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) @@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } + // Test predictions on a stream + test("predictions") { + + // create model initialized with true weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(10.0, 10.0)) + .setStepSize(0.1) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // apply model predictions to test stream + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + // collect the output as (true, estimated) tuples + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // compute the mean absolute error and check that it's always less than 0.1 + val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) + assert(errors.forall(x => x <= 0.1)) + + } + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index cc178fba12c9d..f095da9cb55d3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -242,7 +242,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = ssc.graph.getOutputStreams. + filter(_.isInstanceOf[TestOutputStreamWithPartitions[_]]). + head.asInstanceOf[TestOutputStreamWithPartitions[V]] val output = outputStream.output try { From c3952b092a2f7fea4798f4cb7abac300b9dc9c29 Mon Sep 17 00:00:00 2001 From: Vida Ha Date: Tue, 19 Aug 2014 13:35:05 -0700 Subject: [PATCH 193/231] SPARK-2333 - spark_ec2 script should allow option for existing security group - Uses the name tag to identify machines in a cluster. - Allows overriding the security group name so it doesn't need to coincide with the cluster name. - Outputs the request id's of up to 10 pending spot instance requests. Author: Vida Ha Closes #1899 from vidaha/vida/ec2-reuse-security-group and squashes the following commits: c80d5c3 [Vida Ha] wrap retries in a try catch block b2989d5 [Vida Ha] SPARK-2333: spark_ec2 script should allow option for existing security group (cherry picked from commit 94053a7b766788bb62e2dbbf352ccbcc75f71fc0) Signed-off-by: Josh Rosen --- docs/ec2-scripts.md | 14 +++++---- ec2/spark_ec2.py | 71 +++++++++++++++++++++++++++++++-------------- 2 files changed, 57 insertions(+), 28 deletions(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 156a727026790..f5ac6d894e1eb 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -12,14 +12,16 @@ on the [Amazon Web Services site](http://aws.amazon.com/). `spark-ec2` is designed to manage multiple named clusters. You can launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster is -identified by placing its machines into EC2 security groups whose names -are derived from the name of the cluster. For example, a cluster named +shutdown an existing cluster, or log into a cluster. Each cluster +launches a set of instances, which are tagged with the cluster name, +and placed into EC2 security groups. If you don't specify a security +group, the `spark-ec2` script will create security groups based on the +cluster name you request. For example, a cluster named `test` will contain a master node in a security group called `test-master`, and a number of slave nodes in a security group called -`test-slaves`. The `spark-ec2` script will create these security groups -for you based on the cluster name you request. You can also use them to -identify machines belonging to each cluster in the Amazon EC2 Console. +`test-slaves`. You can also specify a security group prefix to be used +in place of the cluster name. Machines in a cluster can be identified +by looking for the "Name" tag of the instance in the Amazon EC2 Console. # Before You Start diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index fc6fb1db59424..a979891662fb7 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -124,7 +124,7 @@ def parse_args(): help="The SSH user you want to connect as (default: root)") parser.add_option( "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") + help="When destroying a cluster, delete the security groups that were created.") parser.add_option( "--use-existing-master", action="store_true", default=False, help="Launch fresh slaves, but use an existing stopped master if possible") @@ -138,7 +138,9 @@ def parse_args(): parser.add_option( "--user-data", type="string", default="", help="Path to a user-data file (most AMI's interpret this as an initialization script)") - + parser.add_option( + "--security-group-prefix", type="string", default=None, + help="Use this prefix for the security group rather than the cluster name.") (opts, args) = parser.parse_args() if len(args) != 2: @@ -285,8 +287,12 @@ def launch_cluster(conn, opts, cluster_name): user_data_content = user_data_file.read() print "Setting up security groups..." - master_group = get_or_make_group(conn, cluster_name + "-master") - slave_group = get_or_make_group(conn, cluster_name + "-slaves") + if opts.security_group_prefix is None: + master_group = get_or_make_group(conn, cluster_name + "-master") + slave_group = get_or_make_group(conn, cluster_name + "-slaves") + else: + master_group = get_or_make_group(conn, opts.security_group_prefix + "-master") + slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves") if master_group.rules == []: # Group was just now created master_group.authorize(src_group=master_group) master_group.authorize(src_group=slave_group) @@ -310,12 +316,11 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0') slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') - # Check if instances are already running in our groups + # Check if instances are already running with the cluster name existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print >> stderr, ("ERROR: There are already instances for name: %s " % cluster_name) sys.exit(1) # Figure out Spark AMI @@ -371,9 +376,13 @@ def launch_cluster(conn, opts, cluster_name): for r in reqs: id_to_req[r.id] = r active_instance_ids = [] + outstanding_request_ids = [] for i in my_req_ids: - if i in id_to_req and id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) + if i in id_to_req: + if id_to_req[i].state == "active": + active_instance_ids.append(id_to_req[i].instance_id) + else: + outstanding_request_ids.append(i) if len(active_instance_ids) == opts.slaves: print "All %d slaves granted" % opts.slaves reservations = conn.get_all_instances(active_instance_ids) @@ -382,8 +391,8 @@ def launch_cluster(conn, opts, cluster_name): slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print "%d of %d slaves granted, waiting longer for request ids including %s" % ( + len(active_instance_ids), opts.slaves, outstanding_request_ids[0:10]) except: print "Canceling spot instance requests" conn.cancel_spot_instance_requests(my_req_ids) @@ -440,14 +449,29 @@ def launch_cluster(conn, opts, cluster_name): print "Launched master in %s, regid = %s" % (zone, master_res.id) # Give the instances descriptive names + # TODO: Add retry logic for tagging with name since it's used to identify a cluster. for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id) + for i in range(0, 5): + try: + master.add_tag(key='Name', value=name) + except: + print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) + if (i == 5): + raise "Error - failed max attempts to add name tag" + time.sleep(5) + + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id) + for i in range(0, 5): + try: + slave.add_tag(key='Name', value=name) + except: + print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) + if (i == 5): + raise "Error - failed max attempts to add name tag" + time.sleep(5) # Return all the instances return (master_nodes, slave_nodes) @@ -463,10 +487,10 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): for res in reservations: active = [i for i in res.instances if is_active(i)] for inst in active: - group_names = [g.name for g in inst.groups] - if group_names == [cluster_name + "-master"]: + name = inst.tags.get(u'Name', "") + if name.startswith(cluster_name + "-master"): master_nodes.append(inst) - elif group_names == [cluster_name + "-slaves"]: + elif name.startswith(cluster_name + "-slave"): slave_nodes.append(inst) if any((master_nodes, slave_nodes)): print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) @@ -474,7 +498,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in with name " + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) @@ -816,7 +840,10 @@ def real_main(): # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." - group_names = [cluster_name + "-master", cluster_name + "-slaves"] + if opts.security_group_prefix is None: + group_names = [cluster_name + "-master", cluster_name + "-slaves"] + else: + group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] attempt = 1 while attempt <= 3: From f6b4ab83c073d84d1ca26f2ed1168fdbd1c928db Mon Sep 17 00:00:00 2001 From: hzw19900416 Date: Tue, 19 Aug 2014 14:04:49 -0700 Subject: [PATCH 194/231] Move a bracket in validateSettings of SparkConf Move a bracket in validateSettings of SparkConf Author: hzw19900416 Closes #2012 from hzw19900416/codereading and squashes the following commits: e717fb6 [hzw19900416] Move a bracket in validateSettings of SparkConf (cherry picked from commit 76eaeb4523ee01cabbea2d867daac48a277885a1) Signed-off-by: Josh Rosen --- core/src/main/scala/org/apache/spark/SparkConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 13f0bff7ee507..b4f321ec99e78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { // Validate spark.executor.extraJavaOptions settings.get(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { - val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts)'. " + + val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." throw new Exception(msg) } From 3540d4b387568a4017fcd772233e4e10c1beb1b4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 19 Aug 2014 14:46:32 -0700 Subject: [PATCH 195/231] [SPARK-2790] [PySpark] fix zip with serializers which have different batch sizes. If two RDDs have different batch size in serializers, then it will try to re-serialize the one with smaller batch size, then call RDD.zip() in Spark. Author: Davies Liu Closes #1894 from davies/zip and squashes the following commits: c4652ea [Davies Liu] add more test cases 6d05fc8 [Davies Liu] Merge branch 'master' into zip 813b1e4 [Davies Liu] add more tests for failed cases a4aafda [Davies Liu] fix zip with serializers which have different batch sizes. (cherry picked from commit d7e80c2597d4a9cae2e0cb35a86f7889323f4cbb) Signed-off-by: Josh Rosen --- python/pyspark/rdd.py | 25 +++++++++++++++++++++++++ python/pyspark/serializers.py | 3 +++ python/pyspark/tests.py | 27 ++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 86cd89b245aea..140cbe05a43b0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1687,6 +1687,31 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") + + def get_batch_size(ser): + if isinstance(ser, BatchedSerializer): + return ser.batchSize + return 0 + + def batch_as(rdd, batchSize): + ser = rdd._jrdd_deserializer + if isinstance(ser, BatchedSerializer): + ser = ser.serializer + return rdd._reserialize(BatchedSerializer(ser, batchSize)) + + my_batch = get_batch_size(self._jrdd_deserializer) + other_batch = get_batch_size(other._jrdd_deserializer) + if my_batch != other_batch: + # use the greatest batchSize to batch the other one. + if my_batch > other_batch: + other = batch_as(other, my_batch) + else: + self = batch_as(self, other_batch) + + # There will be an Exception in JVM if there are different number + # of items in each partitions. pairRDD = self._jrdd.zip(other._jrdd) deserializer = PairDeserializer(self._jrdd_deserializer, other._jrdd_deserializer) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 74870c0edcf99..fc49aa42dbaf9 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -255,6 +255,9 @@ def __init__(self, key_ser, val_ser): def load_stream(self, stream): for (keys, vals) in self.prepare_keys_values(stream): + if len(keys) != len(vals): + raise ValueError("Can not deserialize RDD with different number of items" + " in pair: (%d, %d)" % (len(keys), len(vals))) for pair in izip(keys, vals): yield pair diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 69d543d9d045d..51bfbb47e53c2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -39,7 +39,7 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles -from pyspark.serializers import read_int +from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger _have_scipy = False @@ -339,6 +339,31 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEquals(N, m) + def test_zip_with_different_serializers(self): + a = self.sc.parallelize(range(5)) + b = self.sc.parallelize(range(100, 105)) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + b = b._reserialize(MarshalSerializer()) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + + def test_zip_with_different_number_of_items(self): + a = self.sc.parallelize(range(5), 2) + # different number of partitions + b = self.sc.parallelize(range(100, 106), 3) + self.assertRaises(ValueError, lambda: a.zip(b)) + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEquals(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) + class TestIO(PySparkTestCase): From d371c71cb19f62b1d2594f92f616abf09d9777a7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 19 Aug 2014 16:06:48 -0700 Subject: [PATCH 196/231] [SPARK-3136][MLLIB] Create Java-friendly methods in RandomRDDs Though we don't use default argument for methods in RandomRDDs, it is still not easy for Java users to use because the output type is either `RDD[Double]` or `RDD[Vector]`. Java users should expect `JavaDoubleRDD` and `JavaRDD[Vector]`, respectively. We should create dedicated methods for Java users, and allow default arguments in Scala methods in RandomRDDs, to make life easier for both Java and Scala users. This PR also contains documentation for random data generation. brkyvz Author: Xiangrui Meng Closes #2041 from mengxr/stat-doc and squashes the following commits: fc5eedf [Xiangrui Meng] add missing comma ffde810 [Xiangrui Meng] address comments aef6d07 [Xiangrui Meng] add doc for random data generation b99d94b [Xiangrui Meng] add java-friendly methods to RandomRDDs (cherry picked from commit 825d4fe47b9c4d48de88622dd48dcf83beb8b80a) Signed-off-by: Xiangrui Meng --- docs/mllib-guide.md | 2 +- docs/mllib-stats.md | 74 ++- .../mllib/random/RandomDataGenerator.scala | 18 +- .../spark/mllib/random/RandomRDDs.scala | 476 +++++++----------- .../mllib/random/JavaRandomRDDsSuite.java | 134 +++++ python/pyspark/mllib/random.py | 20 +- 6 files changed, 418 insertions(+), 306 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 23d5a0c4607af..ca0a84a8c53fd 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -9,7 +9,7 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Data types](mllib-basics.html) * [Basic statistics](mllib-stats.html) - * data generators + * random data generation * stratified sampling * summary statistics * hypothesis testing diff --git a/docs/mllib-stats.md b/docs/mllib-stats.md index ca9ef46c15186..f25dca746ba3a 100644 --- a/docs/mllib-stats.md +++ b/docs/mllib-stats.md @@ -25,7 +25,79 @@ displayTitle: MLlib - Statistics Functionality \newcommand{\zero}{\mathbf{0}} \]` -## Data Generators +## Random data generation + +Random data generation is useful for randomized algorithms, prototyping, and performance testing. +MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +uniform, standard normal, or Poisson. + +
    +
    +[`RandomRDDs`](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.random.RandomRDDs._ + +val sc: SparkContext = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +val u = normalRDD(sc, 1000000L, 10) +// Apply a transform to get a random double RDD following `N(1, 4)`. +val v = u.map(x => 1.0 + 2.0 * x) +{% endhighlight %} +
    + +
    +[`RandomRDDs`](api/java/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight java %} +import org.apache.spark.SparkContext; +import org.apache.spark.api.JavaDoubleRDD; +import static org.apache.spark.mllib.random.RandomRDDs.*; + +JavaSparkContext jsc = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +JavaDoubleRDD u = normalJavaRDD(jsc, 1000000L, 10); +// Apply a transform to get a random double RDD following `N(1, 4)`. +JavaDoubleRDD v = u.map( + new Function() { + public Double call(Double x) { + return 1.0 + 2.0 * x; + } + }); +{% endhighlight %} +
    + +
    +[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight python %} +from pyspark.mllib.random import RandomRDDs + +sc = ... # SparkContext + +# Generate a random double RDD that contains 1 million i.i.d. values drawn from the +# standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +# Apply a transform to get a random double RDD following `N(1, 4)`. +v = u.map(lambda x: 1.0 + 2.0 * x) +{% endhighlight %} +
    + +
    ## Stratified Sampling diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9cab49f6ed1f0..28179fbc450c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -20,14 +20,14 @@ package org.apache.spark.mllib.random import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** - * :: Experimental :: + * :: DeveloperApi :: * Trait for random data generators that generate i.i.d. data. */ -@Experimental +@DeveloperApi trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** @@ -43,10 +43,10 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from U[0.0, 1.0] */ -@Experimental +@DeveloperApi class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. @@ -62,10 +62,10 @@ class UniformGenerator extends RandomDataGenerator[Double] { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from the standard normal distribution. */ -@Experimental +@DeveloperApi class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. @@ -81,12 +81,12 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from the Poisson distribution with the given mean. * * @param mean mean for the Poisson distribution. */ -@Experimental +@DeveloperApi class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { private var rng = new Poisson(mean, new DRand) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 36270369526cd..c5f4b084321f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -20,9 +20,10 @@ package org.apache.spark.mllib.random import scala.reflect.ClassTag import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} +import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -34,335 +35,279 @@ import org.apache.spark.util.Utils object RandomRDDs { /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution `U(0.0, 1.0)`. * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. + * To transform the distribution in the generated RDD from `U(0.0, 1.0)` to `U(a, b)`, use + * `RandomRDDs.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Double] comprised of i.i.d. samples ~ `U(0.0, 1.0)`. */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { + def uniformRDD( + sc: SparkContext, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitions, seed) + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. - * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n, p).map(v => a + (b - a) * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + * Java-friendly version of [[RandomRDDs#uniformRDD]]. */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { - uniformRDD(sc, size, numPartitions, Utils.random.nextLong) + def uniformJavaRDD( + jsc: JavaSparkContext, + size: Long, + numPartitions: Int, + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions, seed)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n).map(v => a + (b - a) * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + * [[RandomRDDs#uniformJavaRDD]] with the default seed. */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long): RDD[Double] = { - uniformRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) + def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. - * - * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed. */ - @Experimental - def normalRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { - val normal = new StandardNormalGenerator() - randomRDD(sc, normal, size, numPartitions, seed) + def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size)) } /** - * :: Experimental :: * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. * * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p).map(v => mean + sigma * v)`. + * `N(mean, sigma^2^)`, use `RandomRDDs.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). */ - @Experimental - def normalRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { - normalRDD(sc, size, numPartitions, Utils.random.nextLong) + def normalRDD( + sc: SparkContext, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { + val normal = new StandardNormalGenerator() + randomRDD(sc, normal, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n).map(v => mean + sigma * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * Java-friendly version of [[RandomRDDs#normalRDD]]. */ - @Experimental - def normalRDD(sc: SparkContext, size: Long): RDD[Double] = { - normalRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) + def normalJavaRDD( + jsc: JavaSparkContext, + size: Long, + numPartitions: Int, + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions, seed)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * [[RandomRDDs#normalJavaRDD]] with the default seed. */ - @Experimental - def poissonRDD(sc: SparkContext, - mean: Double, - size: Long, - numPartitions: Int, - seed: Long): RDD[Double] = { - val poisson = new PoissonGenerator(mean) - randomRDD(sc, poisson, size, numPartitions, seed) + def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed. */ - @Experimental - def poissonRDD(sc: SparkContext, mean: Double, size: Long, numPartitions: Int): RDD[Double] = { - poissonRDD(sc, mean, size, numPartitions, Utils.random.nextLong) + def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size)) } /** - * :: Experimental :: * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). */ - @Experimental - def poissonRDD(sc: SparkContext, mean: Double, size: Long): RDD[Double] = { - poissonRDD(sc, mean, size, sc.defaultParallelism, Utils.random.nextLong) + def poissonRDD( + sc: SparkContext, + mean: Double, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { + val poisson = new PoissonGenerator(mean) + randomRDD(sc, poisson, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * Java-friendly version of [[RandomRDDs#poissonRDD]]. */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, - generator: RandomDataGenerator[T], + def poissonJavaRDD( + jsc: JavaSparkContext, + mean: Double, size: Long, numPartitions: Int, - seed: Long): RDD[T] = { - new RandomRDD[T](sc, size, numPartitions, generator, seed) + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size, numPartitions, seed)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * [[RandomRDDs#poissonJavaRDD]] with the default seed. */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, - generator: RandomDataGenerator[T], + def poissonJavaRDD( + jsc: JavaSparkContext, + mean: Double, size: Long, - numPartitions: Int): RDD[T] = { - randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong) + numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size, numPartitions)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * sc.defaultParallelism used for the number of partitions in the RDD. + * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed. + */ + def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size)) + } + + /** + * :: DeveloperApi :: + * Generates an RDD comprised of i.i.d. samples produced by the input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. + * @param generator RandomDataGenerator used to populate the RDD. * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, + @DeveloperApi + def randomRDD[T: ClassTag]( + sc: SparkContext, generator: RandomDataGenerator[T], - size: Long): RDD[T] = { - randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[T] = { + new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } // TODO Generate RDD[Vector] from multivariate distributions. /** - * :: Experimental :: * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. + * uniform distribution on `U(0.0, 1.0)`. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. */ - @Experimental - def uniformVectorRDD(sc: SparkContext, + def uniformVectorRDD( + sc: SparkContext, numRows: Long, numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { val uniform = new UniformGenerator() - randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) + randomVectorRDD(sc, uniform, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. + * Java-friendly version of [[RandomRDDs#uniformVectorRDD]]. */ - @Experimental - def uniformVectorRDD(sc: SparkContext, + def uniformJavaVectorRDD( + jsc: JavaSparkContext, numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - uniformVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols, numPartitions, seed).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. + * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed. */ - @Experimental - def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { - uniformVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + def uniformJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed. + */ + def uniformJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols).toJavaRDD() } /** - * :: Experimental :: * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + */ + def normalVectorRDD( + sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + val normal = new StandardNormalGenerator() + randomVectorRDD(sc, normal, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#normalVectorRDD]]. */ - @Experimental - def normalVectorRDD(sc: SparkContext, + def normalJavaVectorRDD( + jsc: JavaSparkContext, numRows: Long, numCols: Int, numPartitions: Int, - seed: Long): RDD[Vector] = { - val uniform = new StandardNormalGenerator() - randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) + seed: Long): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols, numPartitions, seed).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * standard normal distribution. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). + * [[RandomRDDs#normalJavaVectorRDD]] with the default seed. */ - @Experimental - def normalVectorRDD(sc: SparkContext, + def normalJavaVectorRDD( + jsc: JavaSparkContext, numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - normalVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols, numPartitions).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * standard normal distribution. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). + * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed. */ - @Experimental - def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { - normalVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + def normalJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols).toJavaRDD() } /** - * :: Experimental :: * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * @@ -370,124 +315,85 @@ object RandomRDDs { * @param mean Mean, or lambda, for the Poisson distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) + * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ - @Experimental - def poissonVectorRDD(sc: SparkContext, + def poissonVectorRDD( + sc: SparkContext, mean: Double, numRows: Long, numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { val poisson = new PoissonGenerator(mean) - randomVectorRDD(sc, poisson, numRows, numCols, numPartitions, seed) + randomVectorRDD(sc, poisson, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * Java-friendly version of [[RandomRDDs#poissonVectorRDD]]. */ - @Experimental - def poissonVectorRDD(sc: SparkContext, + def poissonJavaVectorRDD( + jsc: JavaSparkContext, mean: Double, numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - poissonVectorRDD(sc, mean, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols, numPartitions, seed).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * Poisson distribution with the input mean. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed. */ - @Experimental - def poissonVectorRDD(sc: SparkContext, + def poissonJavaVectorRDD( + jsc: JavaSparkContext, mean: Double, numRows: Long, - numCols: Int): RDD[Vector] = { - poissonVectorRDD(sc, mean, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols, numPartitions).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed. */ - @Experimental - def randomVectorRDD(sc: SparkContext, - generator: RandomDataGenerator[Double], + def poissonJavaVectorRDD( + jsc: JavaSparkContext, + mean: Double, numRows: Long, - numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { - new RandomVectorRDD(sc, numRows, numCols, numPartitions, generator, seed) + numCols: Int): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols).toJavaRDD() } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. + * input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. + * @param generator RandomDataGenerator used to populate the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ - @Experimental + @DeveloperApi def randomVectorRDD(sc: SparkContext, generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - randomVectorRDD(sc, generator, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + new RandomVectorRDD( + sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. */ - @Experimental - def randomVectorRDD(sc: SparkContext, - generator: RandomDataGenerator[Double], - numRows: Long, - numCols: Int): RDD[Vector] = { - randomVectorRDD(sc, generator, numRows, numCols, - sc.defaultParallelism, Utils.random.nextLong) + private def numPartitionsOrDefault(sc: SparkContext, numPartitions: Int): Int = { + if (numPartitions > 0) numPartitions else sc.defaultMinPartitions } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java new file mode 100644 index 0000000000000..a725736ca1a58 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.random; + +import com.google.common.collect.Lists; +import org.apache.spark.api.java.JavaRDD; +import org.junit.Assert; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import static org.apache.spark.mllib.random.RandomRDDs.*; + +public class JavaRandomRDDsSuite { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testUniformRDD() { + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); + JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); + JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + public void testNormalRDD() { + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); + JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); + JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + public void testPoissonRDD() { + double mean = 2.0; + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); + JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); + JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testUniformVectorRDD() { + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); + JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); + JavaRDD rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testNormalVectorRDD() { + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); + JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); + JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testPoissonVectorRDD() { + double mean = 2.0; + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); + JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); + JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } +} diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 3f3b19053d32e..4dc1a4a912421 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -35,10 +35,10 @@ class RandomRDDs: def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the - uniform distribution on [0.0, 1.0]. + uniform distribution U(0.0, 1.0). - To transform the distribution in the generated RDD from U[0.0, 1.0] - to U[a, b], use + To transform the distribution in the generated RDD from U(0.0, 1.0) + to U(a, b), use C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} @@ -60,11 +60,11 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): @staticmethod def normalRDD(sc, size, numPartitions=None, seed=None): """ - Generates an RDD comprised of i.i.d samples from the standard normal + Generates an RDD comprised of i.i.d. samples from the standard normal distribution. To transform the distribution in the generated RDD from standard normal - to some other normal N(mean, sigma), use + to some other normal N(mean, sigma^2), use C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} @@ -84,7 +84,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): @staticmethod def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ - Generates an RDD comprised of i.i.d samples from the Poisson + Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. >>> mean = 100.0 @@ -105,8 +105,8 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): @staticmethod def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn - from the uniform distribution on [0.0 1.0]. + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the uniform distribution U(0.0, 1.0). >>> import numpy as np >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) @@ -125,7 +125,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn + Generates an RDD comprised of vectors containing i.i.d. samples drawn from the standard normal distribution. >>> import numpy as np @@ -145,7 +145,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn + Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Poisson distribution with the input mean. >>> import numpy as np From 66b4c81db7e826c00f7fb449b8a8af810cf7dd9a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 19 Aug 2014 17:40:35 -0700 Subject: [PATCH 197/231] [SPARK-2468] Netty based block server / client module Previous pull request (#1907) was reverted. This brings it back. Still looking into the hang. Author: Reynold Xin Closes #1971 from rxin/netty1 and squashes the following commits: b0be96f [Reynold Xin] Added test to make sure outstandingRequests are cleaned after firing the events. 4c6d0ee [Reynold Xin] Pass callbacks cleanly. 603dce7 [Reynold Xin] Upgrade Netty to 4.0.23 to fix the DefaultFileRegion bug. 88be1d4 [Reynold Xin] Downgrade to 4.0.21 to work around a bug in writing DefaultFileRegion. 002626a [Reynold Xin] Remove netty-test-file.txt. db6e6e0 [Reynold Xin] Revert "Revert "[SPARK-2468] Netty based block server / client module"" (cherry picked from commit 8b9dc991018842e01f4b93870a2bc2c2cb9ea4ba) Signed-off-by: Reynold Xin --- .../spark/network/netty/FileClient.scala | 85 --------- .../network/netty/FileClientHandler.scala | 50 ------ .../spark/network/netty/FileHeader.scala | 71 -------- .../spark/network/netty/FileServer.scala | 91 ---------- .../network/netty/FileServerHandler.scala | 68 -------- .../spark/network/netty/NettyConfig.scala | 59 +++++++ .../spark/network/netty/ShuffleCopier.scala | 118 ------------- .../spark/network/netty/ShuffleSender.scala | 71 -------- .../BlockClientListener.scala} | 16 +- .../netty/client/BlockFetchingClient.scala | 132 ++++++++++++++ .../client/BlockFetchingClientFactory.scala | 99 +++++++++++ .../client/BlockFetchingClientHandler.scala | 103 +++++++++++ .../netty/client/LazyInitIterator.scala | 44 +++++ .../netty/client/ReferenceCountedBuffer.scala | 47 +++++ .../network/netty/server/BlockHeader.scala | 32 ++++ .../netty/server/BlockHeaderEncoder.scala | 47 +++++ .../network/netty/server/BlockServer.scala | 162 ++++++++++++++++++ .../BlockServerChannelInitializer.scala} | 22 ++- .../netty/server/BlockServerHandler.scala | 140 +++++++++++++++ .../spark/storage/BlockDataProvider.scala | 32 ++++ .../spark/storage/BlockFetcherIterator.scala | 137 +++++++-------- .../apache/spark/storage/BlockManager.scala | 49 +++++- .../storage/BlockNotFoundException.scala | 21 +++ .../spark/storage/DiskBlockManager.scala | 13 +- .../netty/ServerClientIntegrationSuite.scala | 161 +++++++++++++++++ .../BlockFetchingClientHandlerSuite.scala | 105 ++++++++++++ .../server/BlockHeaderEncoderSuite.scala | 64 +++++++ .../server/BlockServerHandlerSuite.scala | 107 ++++++++++++ pom.xml | 2 +- 29 files changed, 1484 insertions(+), 664 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala rename core/src/main/scala/org/apache/spark/network/netty/{FileClientChannelInitializer.scala => client/BlockClientListener.scala} (65%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala rename core/src/main/scala/org/apache/spark/network/netty/{FileServerChannelInitializer.scala => server/BlockServerChannelInitializer.scala} (58%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala deleted file mode 100644 index c6d35f73db545..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.TimeUnit - -import io.netty.bootstrap.Bootstrap -import io.netty.channel.{Channel, ChannelOption, EventLoopGroup} -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.oio.OioSocketChannel - -import org.apache.spark.Logging - -class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging { - - private var channel: Channel = _ - private var bootstrap: Bootstrap = _ - private var group: EventLoopGroup = _ - private val sendTimeout = 60 - - def init(): Unit = { - group = new OioEventLoopGroup - bootstrap = new Bootstrap - bootstrap.group(group) - .channel(classOf[OioSocketChannel]) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout)) - .handler(new FileClientChannelInitializer(handler)) - } - - def connect(host: String, port: Int) { - try { - channel = bootstrap.connect(host, port).sync().channel() - } catch { - case e: InterruptedException => - logWarning("FileClient interrupted while trying to connect", e) - close() - } - } - - def waitForClose(): Unit = { - try { - channel.closeFuture.sync() - } catch { - case e: InterruptedException => - logWarning("FileClient interrupted", e) - } - } - - def sendRequest(file: String): Unit = { - try { - val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS) - if (!bSent) { - throw new RuntimeException("Failed to send") - } - } catch { - case e: InterruptedException => - logError("Error", e) - } - } - - def close(): Unit = { - if (group != null) { - group.shutdownGracefully() - group = null - bootstrap = null - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala deleted file mode 100644 index 017302ec7d33d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.storage.BlockId - - -abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] { - - private var currentHeader: FileHeader = null - - @volatile - private var handlerCalled: Boolean = false - - def isComplete: Boolean = handlerCalled - - def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) - - def handleError(blockId: BlockId) - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE)) - } - if (in.readableBytes >= currentHeader.fileLen) { - handle(ctx, in, currentHeader) - handlerCalled = true - currentHeader = null - ctx.close() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala deleted file mode 100644 index 607e560ff277f..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.buffer._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{BlockId, TestBlockId} - -private[spark] class FileHeader ( - val fileLen: Int, - val blockId: BlockId) extends Logging { - - lazy val buffer: ByteBuf = { - val buf = Unpooled.buffer() - buf.capacity(FileHeader.HEADER_SIZE) - buf.writeInt(fileLen) - buf.writeInt(blockId.name.length) - blockId.name.foreach((x: Char) => buf.writeByte(x)) - // padding the rest of header - if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { - buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) - } else { - throw new Exception("too long header " + buf.readableBytes) - logInfo("too long header") - } - buf - } - -} - -private[spark] object FileHeader { - - val HEADER_SIZE = 40 - - def getFileLenOffset = 0 - def getFileLenSize = Integer.SIZE/8 - - def create(buf: ByteBuf): FileHeader = { - val length = buf.readInt - val idLength = buf.readInt - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buf.readByte().asInstanceOf[Char] - } - val blockId = BlockId(idBuilder.toString()) - new FileHeader(length, blockId) - } - - def main(args:Array[String]) { - val header = new FileHeader(25, TestBlockId("my_block")) - val buf = header.buffer - val newHeader = FileHeader.create(buf) - System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala deleted file mode 100644 index dff77950659af..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup} -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.oio.OioServerSocketChannel - -import org.apache.spark.Logging - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer(pResolver: PathResolver, private var port: Int) extends Logging { - - private val addr: InetSocketAddress = new InetSocketAddress(port) - private var bossGroup: EventLoopGroup = new OioEventLoopGroup - private var workerGroup: EventLoopGroup = new OioEventLoopGroup - - private var channelFuture: ChannelFuture = { - val bootstrap = new ServerBootstrap - bootstrap.group(bossGroup, workerGroup) - .channel(classOf[OioServerSocketChannel]) - .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100)) - .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500)) - .childHandler(new FileServerChannelInitializer(pResolver)) - bootstrap.bind(addr) - } - - try { - val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress] - port = boundAddress.getPort - } catch { - case ie: InterruptedException => - port = 0 - } - - /** Start the file server asynchronously in a new thread. */ - def start(): Unit = { - val blockingThread: Thread = new Thread { - override def run(): Unit = { - try { - channelFuture.channel.closeFuture.sync - logInfo("FileServer exiting") - } catch { - case e: InterruptedException => - logError("File server start got interrupted", e) - } - // NOTE: bootstrap is shutdown in stop() - } - } - blockingThread.setDaemon(true) - blockingThread.start() - } - - def getPort: Int = port - - def stop(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bossGroup != null) { - bossGroup.shutdownGracefully() - bossGroup = null - } - if (workerGroup != null) { - workerGroup.shutdownGracefully() - workerGroup = null - } - } -} - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala deleted file mode 100644 index 96f60b2883ad9..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.FileInputStream - -import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging -import org.apache.spark.storage.{BlockId, FileSegment} - - -class FileServerHandler(pResolver: PathResolver) - extends SimpleChannelInboundHandler[String] with Logging { - - override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = { - val blockId: BlockId = BlockId(blockIdString) - val fileSegment: FileSegment = pResolver.getBlockLocation(blockId) - if (fileSegment == null) { - return - } - val file = fileSegment.file - if (file.exists) { - if (!file.isFile) { - ctx.write(new FileHeader(0, blockId).buffer) - ctx.flush() - return - } - val length: Long = fileSegment.length - if (length > Integer.MAX_VALUE || length <= 0) { - ctx.write(new FileHeader(0, blockId).buffer) - ctx.flush() - return - } - ctx.write(new FileHeader(length.toInt, blockId).buffer) - try { - val channel = new FileInputStream(file).getChannel - ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length)) - } catch { - case e: Exception => - logError("Exception: ", e) - } - } else { - ctx.write(new FileHeader(0, blockId).buffer) - } - ctx.flush() - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError("Exception: ", cause) - ctx.close() - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala new file mode 100644 index 0000000000000..b5870152c5a64 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import org.apache.spark.SparkConf + +/** + * A central location that tracks all the settings we exposed to users. + */ +private[spark] +class NettyConfig(conf: SparkConf) { + + /** Port the server listens on. Default to a random port. */ + private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) + + /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ + private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase + + /** Connect timeout in secs. Default 60 secs. */ + private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 + + /** + * Percentage of the desired amount of time spent for I/O in the child event loops. + * Only applicable in nio and epoll. + */ + private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) + + /** Requested maximum length of the queue of incoming connections. */ + private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + + /** + * Receive buffer size (SO_RCVBUF). + * Note: the optimal size for receive buffer and send buffer should be + * latency * network_bandwidth. + * Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + private[netty] val receiveBuf: Option[Int] = + conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + + /** Send buffer size (SO_SNDBUF). */ + private[netty] val sendBuf: Option[Int] = + conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala deleted file mode 100644 index e7b2855e1ec91..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.Executors - -import scala.collection.JavaConverters._ - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.util.CharsetUtil - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.storage.BlockId - -private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { - - def getBlock(host: String, port: Int, blockId: BlockId, - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - - val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000) - val fc = new FileClient(handler, connectTimeout) - - try { - fc.init() - fc.connect(host, port) - fc.sendRequest(blockId.name) - fc.waitForClose() - fc.close() - } catch { - // Handle any socket-related exceptions in FileClient - case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) - handler.handleError(blockId) - } - } - } - - def getBlock(cmId: ConnectionManagerId, blockId: BlockId, - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) - } - - def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(BlockId, Long)], - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - - for ((blockId, size) <- blocks) { - getBlock(cmId, blockId, resultCollectCallback) - } - } -} - - -private[spark] object ShuffleCopier extends Logging { - - private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit) - extends FileClientHandler with Logging { - - override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) - } - - override def handleError(blockId: BlockId) { - if (!isComplete) { - resultCollectCallBack(blockId, -1, null) - } - } - } - - def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) { - if (size != -1) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: ShuffleCopier ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val blockId = BlockId(args(2)) - val threads = if (args.length > 3) args(3).toInt else 10 - - val copiers = Executors.newFixedThreadPool(80) - val tasks = (for (i <- Range(0, threads)) yield { - Executors.callable(new Runnable() { - def run() { - val copier = new ShuffleCopier(new SparkConf) - copier.getBlock(host, port, blockId, echoResultCollectCallBack) - } - }) - }).asJava - copiers.invokeAll(tasks) - copiers.shutdown() - System.exit(0) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala deleted file mode 100644 index 95958e30f7eeb..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.File - -import org.apache.spark.Logging -import org.apache.spark.util.Utils -import org.apache.spark.storage.{BlockId, FileSegment} - -private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { - - val server = new FileServer(pResolver, portIn) - server.start() - - def stop() { - server.stop() - } - - def port: Int = server.getPort -} - - -/** - * An application for testing the shuffle sender as a standalone program. - */ -private[spark] object ShuffleSender { - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ShuffleSender ") - System.exit(1) - } - - val port = args(0).toInt - val subDirsPerLocalDir = args(1).toInt - val localDirs = args.drop(2).map(new File(_)) - - val pResovler = new PathResolver { - override def getBlockLocation(blockId: BlockId): FileSegment = { - if (!blockId.isShuffle) { - throw new Exception("Block " + blockId + " is not a shuffle block") - } - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(blockId) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId.name) - new FileSegment(file, 0, file.length()) - } - } - val sender = new ShuffleSender(port, pResovler) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala rename to core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala index f4261c13f70a8..e28219dd7745b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala @@ -15,17 +15,15 @@ * limitations under the License. */ -package org.apache.spark.network.netty +package org.apache.spark.network.netty.client -import io.netty.channel.ChannelInitializer -import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.string.StringEncoder +import java.util.EventListener -class FileClientChannelInitializer(handler: FileClientHandler) - extends ChannelInitializer[SocketChannel] { +trait BlockClientListener extends EventListener { + + def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit + + def onFetchFailure(blockId: String, errorMsg: String): Unit - def initChannel(channel: SocketChannel) { - channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler) - } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala new file mode 100644 index 0000000000000..5aea7ba2f3673 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.util.concurrent.TimeoutException + +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel.socket.SocketChannel +import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.string.StringEncoder +import io.netty.util.CharsetUtil + +import org.apache.spark.Logging + +/** + * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. + * Use [[BlockFetchingClientFactory]] to instantiate this client. + * + * The constructor blocks until a connection is successfully established. + * + * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +@throws[TimeoutException] +private[spark] +class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) + extends Logging { + + private val handler = new BlockFetchingClientHandler + + /** Netty Bootstrap for creating the TCP connection. */ + private val bootstrap: Bootstrap = { + val b = new Bootstrap + b.group(factory.workerGroup) + .channel(factory.socketChannelClass) + // Use pooled buffers to reduce temporary buffer allocation + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) + + b.handler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) + // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 + .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) + .addLast("handler", handler) + } + }) + b + } + + /** Netty ChannelFuture for the connection. */ + private val cf: ChannelFuture = bootstrap.connect(hostname, port) + if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { + throw new TimeoutException( + s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") + } + + /** + * Ask the remote server for a sequence of blocks, and execute the callback. + * + * Note that this is asynchronous and returns immediately. Upstream caller should throttle the + * rate of fetching; otherwise we could run out of memory. + * + * @param blockIds sequence of block ids to fetch. + * @param listener callback to fire on fetch success / failure. + */ + def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = { + // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. + // It's also best to limit the number of "flush" calls since it requires system calls. + // Let's concatenate the string and then call writeAndFlush once. + // This is also why this implementation might be more efficient than multiple, separate + // fetch block calls. + var startTime: Long = 0 + logTrace { + startTime = System.nanoTime + s"Sending request $blockIds to $hostname:$port" + } + + blockIds.foreach { blockId => + handler.addRequest(blockId, listener) + } + + val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") + writeFuture.addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace { + val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 + s"Sending request $blockIds to $hostname:$port took $timeTaken ms" + } + } else { + // Fail all blocks. + val errorMsg = + s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" + logError(errorMsg, future.cause) + blockIds.foreach { blockId => + listener.onFetchFailure(blockId, errorMsg) + handler.removeRequest(blockId) + } + } + } + }) + } + + def waitForClose(): Unit = { + cf.channel().closeFuture().sync() + } + + def close(): Unit = cf.channel().close() +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala new file mode 100644 index 0000000000000..2b28402c52b49 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.channel.socket.oio.OioSocketChannel +import io.netty.channel.{EventLoopGroup, Channel} + +import org.apache.spark.SparkConf +import org.apache.spark.network.netty.NettyConfig +import org.apache.spark.util.Utils + +/** + * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses + * the worker thread pool for Netty. + * + * Concurrency: createClient is safe to be called from multiple threads concurrently. + */ +private[spark] +class BlockFetchingClientFactory(val conf: NettyConfig) { + + def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) + + /** A thread factory so the threads are named (for debugging). */ + val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + + /** The following two are instantiated by the [[init]] method, depending ioMode. */ + var socketChannelClass: Class[_ <: Channel] = _ + var workerGroup: EventLoopGroup = _ + + init() + + /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ + private def init(): Unit = { + def initOio(): Unit = { + socketChannelClass = classOf[OioSocketChannel] + workerGroup = new OioEventLoopGroup(0, threadFactory) + } + def initNio(): Unit = { + socketChannelClass = classOf[NioSocketChannel] + workerGroup = new NioEventLoopGroup(0, threadFactory) + } + def initEpoll(): Unit = { + socketChannelClass = classOf[EpollSocketChannel] + workerGroup = new EpollEventLoopGroup(0, threadFactory) + } + + conf.ioMode match { + case "nio" => initNio() + case "oio" => initOio() + case "epoll" => initEpoll() + case "auto" => + // For auto mode, first try epoll (only available on Linux), then nio. + try { + initEpoll() + } catch { + // TODO: Should we log the throwable? But that always happen on non-Linux systems. + // Perhaps the right thing to do is to check whether the system is Linux, and then only + // call initEpoll on Linux. + case e: Throwable => initNio() + } + } + } + + /** + * Create a new BlockFetchingClient connecting to the given remote host / port. + * + * This blocks until a connection is successfully established. + * + * Concurrency: This method is safe to call from multiple threads. + */ + def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { + new BlockFetchingClient(this, remoteHost, remotePort) + } + + def stop(): Unit = { + if (workerGroup != null) { + workerGroup.shutdownGracefully() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala new file mode 100644 index 0000000000000..83265b164299d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import io.netty.buffer.ByteBuf +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging + + +/** + * Handler that processes server responses. It uses the protocol documented in + * [[org.apache.spark.network.netty.server.BlockServer]]. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +private[client] +class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { + + /** Tracks the list of outstanding requests and their listeners on success/failure. */ + private val outstandingRequests = java.util.Collections.synchronizedMap { + new java.util.HashMap[String, BlockClientListener] + } + + def addRequest(blockId: String, listener: BlockClientListener): Unit = { + outstandingRequests.put(blockId, listener) + } + + def removeRequest(blockId: String): Unit = { + outstandingRequests.remove(blockId) + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" + logError(errorMsg, cause) + + // Fire the failure callback for all outstanding blocks + outstandingRequests.synchronized { + val iter = outstandingRequests.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.onFetchFailure(entry.getKey, errorMsg) + } + outstandingRequests.clear() + } + + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { + val totalLen = in.readInt() + val blockIdLen = in.readInt() + val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) + in.readBytes(blockIdBytes) + val blockId = new String(blockIdBytes) + val blockSize = totalLen - math.abs(blockIdLen) - 4 + + def server = ctx.channel.remoteAddress.toString + + // blockIdLen is negative when it is an error message. + if (blockIdLen < 0) { + val errorMessageBytes = new Array[Byte](blockSize) + in.readBytes(errorMessageBytes) + val errorMsg = new String(errorMessageBytes) + logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") + + val listener = outstandingRequests.get(blockId) + if (listener == null) { + // Ignore callback + logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") + } else { + outstandingRequests.remove(blockId) + listener.onFetchFailure(blockId, errorMsg) + } + } else { + logTrace(s"Received block $blockId ($blockSize B) from $server") + + val listener = outstandingRequests.get(blockId) + if (listener == null) { + // Ignore callback + logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") + } else { + outstandingRequests.remove(blockId) + listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in)) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala new file mode 100644 index 0000000000000..9740ee64d1f2d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +/** + * A simple iterator that lazily initializes the underlying iterator. + * + * The use case is that sometimes we might have many iterators open at the same time, and each of + * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). + * This could lead to too many buffers open. If this iterator is used, we lazily initialize those + * buffers. + */ +private[spark] +class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { + + lazy val proxy = createIterator + + override def hasNext: Boolean = { + val gotNext = proxy.hasNext + if (!gotNext) { + close() + } + gotNext + } + + override def next(): Any = proxy.next() + + def close(): Unit = Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala new file mode 100644 index 0000000000000..ea1abf5eccc26 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.io.InputStream +import java.nio.ByteBuffer + +import io.netty.buffer.{ByteBuf, ByteBufInputStream} + + +/** + * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. + * This is a Scala value class. + * + * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of + * reference by the retain method and release method. + */ +private[spark] +class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { + + /** Return the nio ByteBuffer view of the underlying buffer. */ + def byteBuffer(): ByteBuffer = underlying.nioBuffer + + /** Creates a new input stream that starts from the current position of the buffer. */ + def inputStream(): InputStream = new ByteBufInputStream(underlying) + + /** Increment the reference counter by one. */ + def retain(): Unit = underlying.retain() + + /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ + def release(): Unit = underlying.release() +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala new file mode 100644 index 0000000000000..162e9cc6828d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +/** + * Header describing a block. This is used only in the server pipeline. + * + * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. + * + * @param blockSize length of the block content, excluding the length itself. + * If positive, this is the header for a block (not part of the header). + * If negative, this is the header and content for an error message. + * @param blockId block id + * @param error some error message from reading the block + */ +private[server] +class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala new file mode 100644 index 0000000000000..8e4dda4ef8595 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.MessageToByteEncoder + +/** + * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. + */ +private[server] +class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { + override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { + // message = message length (4 bytes) + block id length (4 bytes) + block id + block data + // message length = block id length (4 bytes) + size of block id + size of block data + val blockIdBytes = msg.blockId.getBytes + msg.error match { + case Some(errorMsg) => + val errorBytes = errorMsg.getBytes + out.writeInt(4 + blockIdBytes.length + errorBytes.size) + out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors + out.writeBytes(blockIdBytes) // next is blockId itself + out.writeBytes(errorBytes) // error message + case None => + out.writeInt(4 + blockIdBytes.length + msg.blockSize) + out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length + out.writeBytes(blockIdBytes) // next is blockId itself + // msg of size blockSize will be written by ServerHandler + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala new file mode 100644 index 0000000000000..7b2f9a8d4dfd0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.net.InetSocketAddress + +import io.netty.bootstrap.ServerBootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} +import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.socket.oio.OioServerSocketChannel +import io.netty.handler.codec.LineBasedFrameDecoder +import io.netty.handler.codec.string.StringDecoder +import io.netty.util.CharsetUtil + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.netty.NettyConfig +import org.apache.spark.storage.BlockDataProvider +import org.apache.spark.util.Utils + + +/** + * Server for serving Spark data blocks. + * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. + * + * Protocol for requesting blocks (client to server): + * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" + * + * Protocol for sending blocks (server to client): + * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. + * + * frame-length should not include the length of itself. + * If block-id-length is negative, then this is an error message rather than block-data. The real + * length is the absolute value of the frame-length. + * + */ +private[spark] +class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { + + def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { + this(new NettyConfig(sparkConf), dataProvider) + } + + def port: Int = _port + + def hostName: String = _hostName + + private var _port: Int = conf.serverPort + private var _hostName: String = "" + private var bootstrap: ServerBootstrap = _ + private var channelFuture: ChannelFuture = _ + + init() + + /** Initialize the server. */ + private def init(): Unit = { + bootstrap = new ServerBootstrap + val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") + val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") + + // Use only one thread to accept connections, and 2 * num_cores for worker. + def initNio(): Unit = { + val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) + val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) + workerGroup.setIoRatio(conf.ioRatio) + bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) + } + def initOio(): Unit = { + val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) + val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) + bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) + } + def initEpoll(): Unit = { + val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) + val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) + workerGroup.setIoRatio(conf.ioRatio) + bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) + } + + conf.ioMode match { + case "nio" => initNio() + case "oio" => initOio() + case "epoll" => initEpoll() + case "auto" => + // For auto mode, first try epoll (only available on Linux), then nio. + try { + initEpoll() + } catch { + // TODO: Should we log the throwable? But that always happen on non-Linux systems. + // Perhaps the right thing to do is to check whether the system is Linux, and then only + // call initEpoll on Linux. + case e: Throwable => initNio() + } + } + + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + + // Various (advanced) user-configured settings. + conf.backLog.foreach { backLog => + bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) + } + conf.receiveBuf.foreach { receiveBuf => + bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) + } + conf.sendBuf.foreach { sendBuf => + bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) + } + + bootstrap.childHandler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 + .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) + .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + .addLast("handler", new BlockServerHandler(dataProvider)) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(_port)) + channelFuture.sync() + + val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] + _port = addr.getPort + _hostName = addr.getHostName + } + + /** Shutdown the server. */ + def stop(): Unit = { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly() + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala similarity index 58% rename from core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala rename to core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala index aaa2f913d0269..cc70bd0c5c477 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala @@ -15,20 +15,26 @@ * limitations under the License. */ -package org.apache.spark.network.netty +package org.apache.spark.network.netty.server import io.netty.channel.ChannelInitializer import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters} +import io.netty.handler.codec.LineBasedFrameDecoder import io.netty.handler.codec.string.StringDecoder +import io.netty.util.CharsetUtil +import org.apache.spark.storage.BlockDataProvider -class FileServerChannelInitializer(pResolver: PathResolver) + +/** Channel initializer that sets up the pipeline for the BlockServer. */ +private[netty] +class BlockServerChannelInitializer(dataProvider: BlockDataProvider) extends ChannelInitializer[SocketChannel] { - override def initChannel(channel: SocketChannel): Unit = { - channel.pipeline - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*)) - .addLast("stringDecoder", new StringDecoder) - .addLast("handler", new FileServerHandler(pResolver)) + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 + .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) + .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + .addLast("handler", new BlockServerHandler(dataProvider)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala new file mode 100644 index 0000000000000..40dd5e5d1a2ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.io.FileInputStream +import java.nio.ByteBuffer +import java.nio.channels.FileChannel + +import io.netty.buffer.Unpooled +import io.netty.channel._ + +import org.apache.spark.Logging +import org.apache.spark.storage.{FileSegment, BlockDataProvider} + + +/** + * A handler that processes requests from clients and writes block data back. + * + * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first + * so channelRead0 is called once per line (i.e. per block id). + */ +private[server] +class BlockServerHandler(dataProvider: BlockDataProvider) + extends SimpleChannelInboundHandler[String] with Logging { + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { + def client = ctx.channel.remoteAddress.toString + + // A helper function to send error message back to the client. + def respondWithError(error: String): Unit = { + ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) { + // TODO: Maybe log the success case as well. + logError(s"Error sending error back to $client", future.cause) + ctx.close() + } + } + } + ) + } + + def writeFileSegment(segment: FileSegment): Unit = { + // Send error message back if the block is too large. Even though we are capable of sending + // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. + // Once we fixed the receiving end to be able to process large blocks, this should be removed. + // Also make sure we update BlockHeaderEncoder to support length > 2G. + + // See [[BlockHeaderEncoder]] for the way length is encoded. + if (segment.length + blockId.length + 4 > Int.MaxValue) { + respondWithError(s"Block $blockId size ($segment.length) greater than 2G") + return + } + + var fileChannel: FileChannel = null + try { + fileChannel = new FileInputStream(segment.file).getChannel + } catch { + case e: Exception => + logError( + s"Error opening channel for $blockId in ${segment.file} for request from $client", e) + respondWithError(e.getMessage) + } + + // Found the block. Send it back. + if (fileChannel != null) { + // Write the header and block data. In the case of failures, the listener on the block data + // write should close the connection. + ctx.write(new BlockHeader(segment.length.toInt, blockId)) + + val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) + ctx.writeAndFlush(region).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${segment.length} B) back to $client") + } else { + logError(s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + }) + } + } + + def writeByteBuffer(buf: ByteBuffer): Unit = { + ctx.write(new BlockHeader(buf.remaining, blockId)) + ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") + } else { + logError(s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + }) + } + + logTrace(s"Received request from $client to fetch block $blockId") + + var blockData: Either[FileSegment, ByteBuffer] = null + + // First make sure we can find the block. If not, send error back to the user. + try { + blockData = dataProvider.getBlockData(blockId) + } catch { + case e: Exception => + logError(s"Error opening block $blockId for request from $client", e) + respondWithError(e.getMessage) + return + } + + blockData match { + case Left(segment) => writeFileSegment(segment) + case Right(buf) => writeByteBuffer(buf) + } + + } // end of channelRead0 +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala new file mode 100644 index 0000000000000..5b6d086630834 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + + +/** + * An interface for providing data for blocks. + * + * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. + * + * Aside from unit tests, [[BlockManager]] is the main class that implements this. + */ +private[spark] trait BlockDataProvider { + def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 5f44f5f3197fd..ca60ec78b62ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -18,19 +18,17 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue +import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue import scala.util.{Failure, Success} -import io.netty.buffer.ByteBuf - import org.apache.spark.{Logging, SparkException} import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.network.netty.ShuffleCopier import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -54,18 +52,28 @@ trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] wi private[storage] object BlockFetcherIterator { - // A request to fetch one or more blocks, complete with their sizes + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserializaton to happen in the calling thread); can also - // represent a fetch failure if size == -1. + /** + * Result of a fetch from a remote block. A failure is represented as size == -1. + * @param blockId block id + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param deserialize closure to return the result in the form of an Iterator. + */ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } + // TODO: Refactor this whole thing to make code more reusable. class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], @@ -95,10 +103,10 @@ object BlockFetcherIterator { // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that // the number of bytes in flight is limited to maxBytesInFlight - private val fetchRequests = new Queue[FetchRequest] + protected val fetchRequests = new Queue[FetchRequest] // Current bytes in flight from our requests - private var bytesInFlight = 0L + protected var bytesInFlight = 0L protected def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( @@ -262,77 +270,58 @@ object BlockFetcherIterator { readMetrics: ShuffleReadMetrics) extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { - import blockManager._ + override protected def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID + + // This could throw a TimeoutException. In that case we will just retry the task. + val client = blockManager.nettyBlockClientFactory.createClient( + cmId.host, req.address.nettyPort) + val blocks = req.blocks.map(_._1.toString) + + client.fetchBlocks( + blocks, + new BlockClientListener { + override def onFetchFailure(blockId: String, errorMsg: String): Unit = { + logError(s"Could not get block(s) from $cmId with error: $errorMsg") + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } - private def startCopiers(numCopiers: Int): List[_ <: Thread] = { - (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { - sendRequest(fetchRequestsSync.take()) + override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { + // Increment the reference count so the buffer won't be recycled. + // TODO: This could result in memory leaks when the task is stopped due to exception + // before the iterator is exhausted. + data.retain() + val buf = data.byteBuffer() + val blockSize = buf.remaining() + val bid = BlockId(blockId) + + // TODO: remove code duplication between here and BlockManager.dataDeserialization. + results.put(new FetchResult(bid, sizeMap(bid), () => { + def createIterator: Iterator[Any] = { + val stream = blockManager.wrapForCompression(bid, data.inputStream()) + serializer.newInstance().deserializeStream(stream).asIterator } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - // case _ => throw new SparkException("Exception Throw in Shuffle Copier") + new LazyInitIterator(createIterator) { + // Release the buffer when we are done traversing it. + override def close(): Unit = data.release() + } + })) + + readMetrics.synchronized { + readMetrics.remoteBytesRead += blockSize + readMetrics.remoteBlocksFetched += 1 } + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } - copier.start - copier - }).toList - } - - // keep this to interrupt the threads when necessary - private def stopCopiers() { - for (copier <- copiers) { - copier.interrupt() - } - } - - override protected def sendRequest(req: FetchRequest) { - - def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) { - val fetchResult = new FetchResult(blockId, blockSize, - () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) - results.put(fetchResult) - } - - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) - val cpier = new ShuffleCopier(blockManager.conf) - cpier.getBlocks(cmId, req.blocks, putResult) - logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) - } - - private var copiers: List[_ <: Thread] = null - - override def initialize() { - // Split Local Remote Blocks and set numBlocksToFetch - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - for (request <- Utils.randomize(remoteRequests)) { - fetchRequestsSync.put(request) - } - - copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6)) - logInfo("Started " + fetchRequestsSync.size + " remote fetches in " + - Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - override def next(): (BlockId, Option[Iterator[Any]]) = { - resultsGotten += 1 - val result = results.take() - // If all the results has been retrieved, copiers will exit automatically - (result.blockId, if (result.failed) None else Some(result.deserialize())) + ) } } // End of NettyBlockFetcherIterator diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e4c3d58905e7f..c0491fb55e3a4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -25,17 +25,20 @@ import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Cancellable, Props} +import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.netty.client.BlockFetchingClientFactory +import org.apache.spark.network.netty.server.BlockServer import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ + private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -60,7 +63,7 @@ private[spark] class BlockManager( securityManager: SecurityManager, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager) - extends Logging { + extends BlockDataProvider with Logging { private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) @@ -88,13 +91,25 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } + private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) + // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private val nettyPort: Int = { - val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) - val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0) - if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 + private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = { + if (useNetty) new BlockFetchingClientFactory(conf) else null } + private val nettyBlockServer: BlockServer = { + if (useNetty) { + val server = new BlockServer(conf, this) + logInfo(s"Created NettyBlockServer binding to port: ${server.port}") + server + } else { + null + } + } + + private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0 + val blockManagerId = BlockManagerId( executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) @@ -219,6 +234,20 @@ private[spark] class BlockManager( } } + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + val bid = BlockId(blockId) + if (bid.isShuffle) { + Left(diskBlockManager.getBlockLocation(bid)) + } else { + val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + if (blockBytesOpt.isDefined) { + Right(blockBytesOpt.get) + } else { + throw new BlockNotFoundException(blockId) + } + } + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. @@ -1064,6 +1093,14 @@ private[spark] class BlockManager( connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() + + if (nettyBlockClientFactory != null) { + nettyBlockClientFactory.stop() + } + if (nettyBlockServer != null) { + nettyBlockServer.stop() + } + actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala new file mode 100644 index 0000000000000..9ef453605f4f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + + +class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 4d66ccea211fa..f3da816389581 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -23,7 +23,7 @@ import java.util.{Date, Random, UUID} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.network.netty.{PathResolver, ShuffleSender} +import org.apache.spark.network.netty.PathResolver import org.apache.spark.util.Utils import org.apache.spark.shuffle.sort.SortShuffleManager @@ -52,7 +52,6 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - private var shuffleSender : ShuffleSender = null addShutdownHook() @@ -186,15 +185,5 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, } } } - - if (shuffleSender != null) { - shuffleSender.stop() - } - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - shuffleSender = new ShuffleSender(port, this) - logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}") - shuffleSender.port } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala new file mode 100644 index 0000000000000..02d0ffc86f58f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.{RandomAccessFile, File} +import java.nio.ByteBuffer +import java.util.{Collections, HashSet} +import java.util.concurrent.{TimeUnit, Semaphore} + +import scala.collection.JavaConversions._ + +import io.netty.buffer.{ByteBufUtil, Unpooled} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.SparkConf +import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory} +import org.apache.spark.network.netty.server.BlockServer +import org.apache.spark.storage.{FileSegment, BlockDataProvider} + + +/** + * Test suite that makes sure the server and the client implementations share the same protocol. + */ +class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { + + val bufSize = 100000 + var buf: ByteBuffer = _ + var testFile: File = _ + var server: BlockServer = _ + var clientFactory: BlockFetchingClientFactory = _ + + val bufferBlockId = "buffer_block" + val fileBlockId = "file_block" + + val fileContent = new Array[Byte](1024) + scala.util.Random.nextBytes(fileContent) + + override def beforeAll() = { + buf = ByteBuffer.allocate(bufSize) + for (i <- 1 to bufSize) { + buf.put(i.toByte) + } + buf.flip() + + testFile = File.createTempFile("netty-test-file", "txt") + val fp = new RandomAccessFile(testFile, "rw") + fp.write(fileContent) + fp.close() + + server = new BlockServer(new SparkConf, new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + if (blockId == bufferBlockId) { + Right(buf) + } else if (blockId == fileBlockId) { + Left(new FileSegment(testFile, 10, testFile.length - 25)) + } else { + throw new Exception("Unknown block id " + blockId) + } + } + }) + + clientFactory = new BlockFetchingClientFactory(new SparkConf) + } + + override def afterAll() = { + server.stop() + clientFactory.stop() + } + + /** A ByteBuf for buffer_block */ + lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) + + /** A ByteBuf for file_block */ + lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) + + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = + { + val client = clientFactory.createClient(server.hostName, server.port) + val sem = new Semaphore(0) + val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) + val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) + val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) + + client.fetchBlocks( + blockIds, + new BlockClientListener { + override def onFetchFailure(blockId: String, errorMsg: String): Unit = { + errorBlockIds.add(blockId) + sem.release() + } + + override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { + receivedBlockIds.add(blockId) + data.retain() + receivedBuffers.add(data) + sem.release() + } + } + ) + if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server") + } + client.close() + (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) + } + + test("fetch a ByteBuffer block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) + assert(blockIds === Set(bufferBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch a FileSegment block via zero-copy send") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) + assert(blockIds === Set(fileBlockId)) + assert(buffers.map(_.underlying) === Set(fileBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch a non-existent block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) + assert(blockIds.isEmpty) + assert(buffers.isEmpty) + assert(failBlockIds === Set("random-block")) + } + + test("fetch both ByteBuffer block and FileSegment block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) + assert(blockIds === Set(bufferBlockId, fileBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch both ByteBuffer block and a non-existent block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) + assert(blockIds === Set(bufferBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala new file mode 100644 index 0000000000000..903ab09ae4322 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.nio.ByteBuffer + +import io.netty.buffer.Unpooled +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.{PrivateMethodTester, FunSuite} + + +class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester { + + test("handling block data (successful fetch)") { + val blockId = "test_block" + val blockData = "blahblahblahblahblah" + val totalLength = 4 + blockId.length + blockData.length + + var parsedBlockId: String = "" + var parsedBlockData: String = "" + val handler = new BlockFetchingClientHandler + handler.addRequest(blockId, + new BlockClientListener { + override def onFetchFailure(blockId: String, errorMsg: String): Unit = ??? + override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = { + parsedBlockId = bid + val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) + refCntBuf.byteBuffer().get(bytes) + parsedBlockData = new String(bytes) + } + } + ) + + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + assert(handler.invokePrivate(outstandingRequests()).size === 1) + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself + buf.putInt(totalLength) + buf.putInt(blockId.length) + buf.put(blockId.getBytes) + buf.put(blockData.getBytes) + buf.flip() + + channel.writeInbound(Unpooled.wrappedBuffer(buf)) + assert(parsedBlockId === blockId) + assert(parsedBlockData === blockData) + + assert(handler.invokePrivate(outstandingRequests()).size === 0) + + channel.close() + } + + test("handling error message (failed fetch)") { + val blockId = "test_block" + val errorMsg = "error erro5r error err4or error3 error6 error erro1r" + val totalLength = 4 + blockId.length + errorMsg.length + + var parsedBlockId: String = "" + var parsedErrorMsg: String = "" + val handler = new BlockFetchingClientHandler + handler.addRequest(blockId, new BlockClientListener { + override def onFetchFailure(bid: String, msg: String) ={ + parsedBlockId = bid + parsedErrorMsg = msg + } + override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ??? + }) + + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + assert(handler.invokePrivate(outstandingRequests()).size === 1) + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself + buf.putInt(totalLength) + buf.putInt(-blockId.length) + buf.put(blockId.getBytes) + buf.put(errorMsg.getBytes) + buf.flip() + + channel.writeInbound(Unpooled.wrappedBuffer(buf)) + assert(parsedBlockId === blockId) + assert(parsedErrorMsg === errorMsg) + + assert(handler.invokePrivate(outstandingRequests()).size === 0) + + channel.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala new file mode 100644 index 0000000000000..3ee281cb1350b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import io.netty.buffer.ByteBuf +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + + +class BlockHeaderEncoderSuite extends FunSuite { + + test("encode normal block data") { + val blockId = "test_block" + val channel = new EmbeddedChannel(new BlockHeaderEncoder) + channel.writeOutbound(new BlockHeader(17, blockId, None)) + val out = channel.readOutbound().asInstanceOf[ByteBuf] + assert(out.readInt() === 4 + blockId.length + 17) + assert(out.readInt() === blockId.length) + + val blockIdBytes = new Array[Byte](blockId.length) + out.readBytes(blockIdBytes) + assert(new String(blockIdBytes) === blockId) + assert(out.readableBytes() === 0) + + channel.close() + } + + test("encode error message") { + val blockId = "error_block" + val errorMsg = "error encountered" + val channel = new EmbeddedChannel(new BlockHeaderEncoder) + channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) + val out = channel.readOutbound().asInstanceOf[ByteBuf] + assert(out.readInt() === 4 + blockId.length + errorMsg.length) + assert(out.readInt() === -blockId.length) + + val blockIdBytes = new Array[Byte](blockId.length) + out.readBytes(blockIdBytes) + assert(new String(blockIdBytes) === blockId) + + val errorMsgBytes = new Array[Byte](errorMsg.length) + out.readBytes(errorMsgBytes) + assert(new String(errorMsgBytes) === errorMsg) + assert(out.readableBytes() === 0) + + channel.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala new file mode 100644 index 0000000000000..3239c710f1639 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.io.{RandomAccessFile, File} +import java.nio.ByteBuffer + +import io.netty.buffer.{Unpooled, ByteBuf} +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + +import org.apache.spark.storage.{BlockDataProvider, FileSegment} + + +class BlockServerHandlerSuite extends FunSuite { + + test("ByteBuffer block") { + val expectedBlockId = "test_bytebuffer_block" + val buf = ByteBuffer.allocate(10000) + for (i <- 1 to 10000) { + buf.put(i.toByte) + } + buf.flip() + + val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) + })) + + channel.writeInbound(expectedBlockId) + assert(channel.outboundMessages().size === 2) + + val out1 = channel.readOutbound().asInstanceOf[BlockHeader] + val out2 = channel.readOutbound().asInstanceOf[ByteBuf] + + assert(out1.blockId === expectedBlockId) + assert(out1.blockSize === buf.remaining) + assert(out1.error === None) + + assert(out2.equals(Unpooled.wrappedBuffer(buf))) + + channel.close() + } + + test("FileSegment block via zero-copy") { + val expectedBlockId = "test_file_block" + + // Create random file data + val fileContent = new Array[Byte](1024) + scala.util.Random.nextBytes(fileContent) + val testFile = File.createTempFile("netty-test-file", "txt") + val fp = new RandomAccessFile(testFile, "rw") + fp.write(fileContent) + fp.close() + + val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + Left(new FileSegment(testFile, 15, testFile.length - 25)) + } + })) + + channel.writeInbound(expectedBlockId) + assert(channel.outboundMessages().size === 2) + + val out1 = channel.readOutbound().asInstanceOf[BlockHeader] + val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] + + assert(out1.blockId === expectedBlockId) + assert(out1.blockSize === testFile.length - 25) + assert(out1.error === None) + + assert(out2.count === testFile.length - 25) + assert(out2.position === 15) + } + + test("pipeline exception propagation") { + val blockServerHandler = new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? + }) + val exceptionHandler = new SimpleChannelInboundHandler[String]() { + override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { + throw new Exception("this is an error") + } + } + + val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) + assert(channel.isOpen) + channel.writeInbound("a message to trigger the error") + assert(!channel.isOpen) + } +} diff --git a/pom.xml b/pom.xml index 9e5217e294681..8c4c4af0eda8e 100644 --- a/pom.xml +++ b/pom.xml @@ -419,7 +419,7 @@ io.netty netty-all - 4.0.17.Final + 4.0.23.Final org.apache.derby From 023ed7c0fe9b491dd8d699532260cc2d1c258ebb Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 19 Aug 2014 17:41:37 -0700 Subject: [PATCH 198/231] [MLLIB] minor update to word2vec very minor update Ishiihara Author: Xiangrui Meng Closes #2043 from mengxr/minor-w2v and squashes the following commits: be649fd [Xiangrui Meng] remove map because we only need append eccefcc [Xiangrui Meng] minor updates to word2vec (cherry picked from commit 1870dbaa5591883e61b2173d064c1a67e871b0f5) Signed-off-by: Xiangrui Meng --- .../apache/spark/mllib/feature/Word2Vec.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 1dcaa2cd2e630..c3375ed44fd99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -30,11 +30,9 @@ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap /** * Entry in vocabulary @@ -285,9 +283,9 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - var syn0Global = + val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) - var syn1Global = new Array[Float](vocabSize * vectorSize) + val syn1Global = new Array[Float](vocabSize * vectorSize) var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => @@ -349,21 +347,21 @@ class Word2Vec extends Serializable with Logging { } val syn0Local = model._1 val syn1Local = model._2 - val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2) + val synOut = mutable.ListBuffer.empty[(Int, Array[Float])] var index = 0 while(index < vocabSize) { if (syn0Modify(index) != 0) { - synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)) + synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) } if (syn1Modify(index) != 0) { - synOut.update(index + vocabSize, - syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)) + synOut += ((index + vocabSize, + syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) } index += 1 } - Iterator(synOut) + synOut.toIterator } - val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) => + val synAgg = partial.reduceByKey { case (v1, v2) => blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) v1 }.collect() From d75464de53b1550d99abf9b085181dc72ce837a7 Mon Sep 17 00:00:00 2001 From: freeman Date: Tue, 19 Aug 2014 18:07:42 -0700 Subject: [PATCH 199/231] [SPARK-3112][MLLIB] Add documentation and example for StreamingLR Added a documentation section on StreamingLR to the ``MLlib - Linear Methods``, including a worked example. mengxr tdas Author: freeman Closes #2047 from freeman-lab/streaming-lr-docs and squashes the following commits: 568d250 [freeman] Tweaks to wording / formatting 05a1139 [freeman] Added documentation and example for StreamingLR (cherry picked from commit c7252b0097cfacd36f17357d195b12a59e503b35) Signed-off-by: Xiangrui Meng --- docs/mllib-linear-methods.md | 75 ++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index e504cd7f0f578..9137f9dc1b692 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -518,6 +518,81 @@ print("Mean Squared Error = " + str(MSE))
    +## Streaming linear regression + +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports +streaming linear regression using ordinary least squares. The fitting is similar +to that performed offline, except fitting occurs on each batch of data, so that +the model continually updates to reflect the data from the stream. + +### Examples + +The following example demonstrates how to load training and testing data from two different +input streams of text files, parse the streams as labeled points, fit a linear regression model +online to the first stream, and make predictions on the second stream. + +
    + +
    + +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight scala %} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD + +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight scala %} + +val trainingData = ssc.textFileStream('/training/data/dir').map(LabeledPoint.parse) +val testData = ssc.textFileStream('/testing/data/dir').map(LabeledPoint.parse) + +{% endhighlight %} + +We create our model by initializing the weights to 0 + +{% highlight scala %} + +val numFeatures = 3 +val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.zeros(numFeatures)) + +{% endhighlight %} + +Now we register the streams for training and testing and start the job. +Printing predictions alongside true labels lets us easily see the result. + +{% highlight scala %} + +model.trainOn(trainingData) +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + +ssc.start() +ssc.awaitTermination() + +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! + +
    + +
    + + ## Implementation (developer) Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent From 607735c16b39ea89a11c2a0db38ae7d3422203d6 Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 19 Aug 2014 19:37:02 -0700 Subject: [PATCH 200/231] [SQL] add note of use synchronizedMap in SQLConf Refer to: http://stackoverflow.com/questions/510632/whats-the-difference-between-concurrenthashmap-and-collections-synchronizedmap Collections.synchronizedMap(map) creates a blocking Map which will degrade performance, albeit ensure consistency. So use ConcurrentHashMap(a more effective thread-safe hashmap) instead. also update HiveQuerySuite to fix test error when changed to ConcurrentHashMap. Author: wangfei Author: scwf Closes #1996 from scwf/sqlconf and squashes the following commits: 93bc0c5 [wangfei] revert change of HiveQuerySuite 0cc05dd [wangfei] add note for use synchronizedMap 3c224d31 [scwf] fix formate a7bcb98 [scwf] use ConcurrentHashMap in sql conf, intead synchronizedMap (cherry picked from commit 0e3ab94d413fd70fff748fded42ab5e2ebd66fcc) Signed-off-by: Reynold Xin --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4f2adb006fbc7..5cc41a83cc792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -54,6 +54,7 @@ private[spark] object SQLConf { trait SQLConf { import SQLConf._ + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) From 148e45b91aa4efcc0a7e5b28badff22887a92805 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 19 Aug 2014 21:01:23 -0700 Subject: [PATCH 201/231] [SPARK-3130][MLLIB] detect negative values in naive Bayes because NB treats feature values as term frequencies. jkbradley Author: Xiangrui Meng Closes #2038 from mengxr/nb-neg and squashes the following commits: 52c37c3 [Xiangrui Meng] address comments 65f892d [Xiangrui Meng] detect negative values in nb (cherry picked from commit 068b6fe6a10eb1c6b2102d88832203267f030e85) Signed-off-by: Xiangrui Meng --- docs/mllib-naive-bayes.md | 3 +- .../mllib/classification/NaiveBayes.scala | 28 +++++++++++++++---- .../classification/NaiveBayesSuite.scala | 28 +++++++++++++++++++ 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 86d94aebd9442..7f9d4c6563944 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -17,7 +17,8 @@ Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bay which is typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each -feature represents a term whose value is the frequency of the term. +feature represents a term whose value is the frequency of the term. +Feature values must be nonnegative to represent term frequencies. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 6c7be0a4f1dcb..8c8e4a161aa5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] ( * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { @@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ def run(data: RDD[LabeledPoint]) = { + val requireNonnegativeValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => + sv.values + case dv: DenseVector => + dv.values + } + if (!values.forall(_ >= 0.0)) { + throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( - createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector), - mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze), + createCombiner = (v: Vector) => { + requireNonnegativeValues(v) + (1L, v.toBreeze.toDenseVector) + }, + mergeValue = (c: (Long, BDV[Double]), v: Vector) => { + requireNonnegativeValues(v) + (c._1 + 1L, c._2 += v.toBreeze) + }, mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => (c1._1 + c2._1, c1._2 += c2._2) ).collect() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 06cdd04f5fdae..80989bc074e84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} @@ -95,6 +96,33 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("detect negative values") { + val dense = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(-1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(dense, 2)) + } + val sparse = Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(sparse, 2)) + } + val nan = Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(nan, 2)) + } + } } class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { From d5db95baec62d911c7611f28535f0440440226cb Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 19 Aug 2014 22:05:29 -0700 Subject: [PATCH 202/231] [HOTFIX][Streaming][MLlib] use temp folder for checkpoint or Jenkins will complain about no Apache header in checkpoint files. tdas rxin Author: Xiangrui Meng Closes #2046 from mengxr/tmp-checkpoint and squashes the following commits: 0d3ec73 [Xiangrui Meng] remove ssc.stop 9797843 [Xiangrui Meng] change checkpointDir to lazy val 89964ab [Xiangrui Meng] use temp folder for checkpoint (cherry picked from commit fce5c0fb6384f3a142a4155525a5d62640725150) Signed-off-by: Xiangrui Meng --- .../StreamingLinearRegressionSuite.scala | 6 ------ .../apache/spark/streaming/TestSuiteBase.scala | 17 +++++++++++------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 28489410f8225..03b71301e9ab1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -49,7 +49,6 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data test("parameter accuracy") { - // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0, 0.0)) @@ -82,7 +81,6 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // Test that parameter estimates improve when learning Y = 10*X1 on streaming data test("parameter convergence") { - // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0)) @@ -113,12 +111,10 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) // check that error shrunk on at least 2 batches assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1) - } // Test predictions on a stream test("predictions") { - // create model initialized with true weights val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(10.0, 10.0)) @@ -142,7 +138,5 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // compute the mean absolute error and check that it's always less than 0.1 val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) - } - } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index f095da9cb55d3..759baacaa4308 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -17,18 +17,18 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} -import org.apache.spark.streaming.util.ManualClock +import java.io.{ObjectInputStream, IOException} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.SynchronizedBuffer import scala.reflect.ClassTag -import java.io.{ObjectInputStream, IOException} - import org.scalatest.{BeforeAndAfter, FunSuite} +import com.google.common.io.Files -import org.apache.spark.{SparkContext, SparkConf, Logging} +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD /** @@ -119,7 +119,12 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def batchDuration = Seconds(1) // Directory where the checkpoint data will be saved - def checkpointDir = "checkpoint" + lazy val checkpointDir = { + val dir = Files.createTempDir() + logDebug(s"checkpointDir: $dir") + dir.deleteOnExit() + dir.toString + } // Number of partitions of the input parallel collections created for testing def numInputPartitions = 2 From 08c9973da01620c3592eac46d2437b18c4d5cba7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 19 Aug 2014 22:11:13 -0700 Subject: [PATCH 203/231] [SPARK-3119] Re-implementation of TorrentBroadcast. This is a re-implementation of TorrentBroadcast, with the following changes: 1. Removes most of the mutable, transient state from TorrentBroadcast (e.g. totalBytes, num of blocks fetched). 2. Removes TorrentInfo and TorrentBlock 3. Replaces the BlockManager.getSingle call in readObject with a getLocal, resuling in one less RPC call to the BlockManagerMasterActor to find the location of the block. 4. Removes the metadata block, resulting in one less block to fetch. 5. Removes an extra memory copy for deserialization (by using Java's SequenceInputStream). Basically for a regular broadcasted object with only one block, the number of RPC calls goes from 5+1 to 2+1). Old TorrentBroadcast for object of a single block: 1 RPC to ask for location of the broadcast variable 1 RPC to ask for location of the metadata block 1 RPC to fetch the metadata block 1 RPC to ask for location of the first data block 1 RPC to fetch the first data block 1 RPC to tell the driver we put the first data block in i.e. 5 + 1 New TorrentBroadcast for object of a single block: 1 RPC to ask for location of the first data block 1 RPC to get the first data block 1 RPC to tell the driver we put the first data block in i.e. 2 + 1 Author: Reynold Xin Closes #2030 from rxin/torrentBroadcast and squashes the following commits: 5bacb9d [Reynold Xin] Always add the object to driver's block manager. 0d8ed5b [Reynold Xin] Added getBytes to BlockManager and uses that in TorrentBroadcast. 2d6a5fb [Reynold Xin] Use putBytes/getRemoteBytes throughout. 3670f00 [Reynold Xin] Code review feedback. c1185cd [Reynold Xin] [SPARK-3119] Re-implementation of TorrentBroadcast. (cherry picked from commit 8adfbc2b6b5b647e450d30f89c141f935b6aa94b) Signed-off-by: Reynold Xin --- .../spark/broadcast/BroadcastFactory.scala | 11 + .../spark/broadcast/TorrentBroadcast.scala | 282 +++++++----------- .../spark/broadcast/BroadcastSuite.scala | 128 ++++---- 3 files changed, 181 insertions(+), 240 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index a8c827030a1ef..6a187b40628a2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + + /** + * Creates a new broadcast variable. + * + * @param value value to broadcast + * @param isLocal whether we are in local mode (single JVM process) + * @param id unique id representing this broadcast variable + */ def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit + def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index d8be649f96e5f..6173fd3a69fc7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -18,50 +18,116 @@ package org.apache.spark.broadcast import java.io._ +import java.nio.ByteBuffer +import scala.collection.JavaConversions.asJavaEnumeration import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} +import org.apache.spark.util.ByteBufferInputStream /** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like - * protocol to do a distributed transfer of the broadcasted data to the executors. - * The mechanism is as follows. The driver divides the serializes the broadcasted data, - * divides it into smaller chunks, and stores them in the BlockManager of the driver. - * These chunks are reported to the BlockManagerMaster so that all the executors can - * learn the location of those chunks. The first time the broadcast variable (sent as - * part of task) is deserialized at a executor, all the chunks are fetched using - * the BlockManager. When all the chunks are fetched (initially from the driver's - * BlockManager), they are combined and deserialized to recreate the broadcasted data. - * However, the chunks are also stored in the BlockManager and reported to the - * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns - * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be - * made to other executors who already have those chunks, resulting in a distributed - * fetching. This prevents the driver from being the bottleneck in sending out multiple - * copies of the broadcast data (one per executor) as done by the - * [[org.apache.spark.broadcast.HttpBroadcast]]. + * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. + * + * The mechanism is as follows: + * + * The driver divides the serialized object into small chunks and + * stores those chunks in the BlockManager of the driver. + * + * On each executor, the executor first attempts to fetch the object from its BlockManager. If + * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or + * other executors if available. Once it gets the chunks, it puts the chunks in its own + * BlockManager, ready for other executors to fetch from. + * + * This prevents the driver from being the bottleneck in sending out multiple copies of the + * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. + * + * @param obj object to broadcast + * @param isLocal whether Spark is running in local mode (single JVM process). + * @param id A unique identifier for the broadcast variable. */ private[spark] class TorrentBroadcast[T: ClassTag]( - @transient var value_ : T, isLocal: Boolean, id: Long) + obj : T, + @transient private val isLocal: Boolean, + id: Long) extends Broadcast[T](id) with Logging with Serializable { - override protected def getValue() = value_ + /** + * Value of the broadcast object. On driver, this is set directly by the constructor. + * On executors, this is reconstructed by [[readObject]], which builds this value by reading + * blocks from the driver and/or other executors. + */ + @transient private var _value: T = obj private val broadcastId = BroadcastBlockId(id) - SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + /** Total number of blocks this broadcast variable contains. */ + private val numBlocks: Int = writeBlocks() + + override protected def getValue() = _value + + /** + * Divide the object into multiple blocks and put those blocks in the block manager. + * + * @return number of blocks this broadcast variable is divided into + */ + private def writeBlocks(): Int = { + // For local mode, just put the object in the BlockManager so we can find it later. + SparkEnv.get.blockManager.putSingle( + broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + + if (!isLocal) { + val blocks = TorrentBroadcast.blockifyObject(_value) + blocks.zipWithIndex.foreach { case (block, i) => + SparkEnv.get.blockManager.putBytes( + BroadcastBlockId(id, "piece" + i), + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + } + blocks.length + } else { + 0 + } + } + + /** Fetch torrent blocks from the driver and/or other executors. */ + private def readBlocks(): Array[ByteBuffer] = { + // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported + // to the driver, so other executors can pull these chunks from this executor as well. + val blocks = new Array[ByteBuffer](numBlocks) + val bm = SparkEnv.get.blockManager - @transient private var arrayOfBlocks: Array[TorrentBlock] = null - @transient private var totalBlocks = -1 - @transient private var totalBytes = -1 - @transient private var hasBlocks = 0 + for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { + val pieceId = BroadcastBlockId(id, "piece" + pid) - if (!isLocal) { - sendBroadcast() + // First try getLocalBytes because there is a chance that previous attempts to fetch the + // broadcast blocks have already fetched some of the blocks. In that case, some blocks + // would be available locally (on this executor). + var blockOpt = bm.getLocalBytes(pieceId) + if (!blockOpt.isDefined) { + blockOpt = bm.getRemoteBytes(pieceId) + blockOpt match { + case Some(block) => + // If we found the block from remote executors/driver's BlockManager, put the block + // in this executor's BlockManager. + SparkEnv.get.blockManager.putBytes( + pieceId, + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + // If we get here, the option is defined. + blocks(pid) = blockOpt.get + } + blocks } /** @@ -79,26 +145,6 @@ private[spark] class TorrentBroadcast[T: ClassTag]( TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) } - private def sendBroadcast() { - val tInfo = TorrentBroadcast.blockifyObject(value_) - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - hasBlocks = tInfo.totalBlocks - - // Store meta-info - val metaId = BroadcastBlockId(id, "meta") - val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) - SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - - // Store individual pieces - for (i <- 0 until totalBlocks) { - val pieceId = BroadcastBlockId(id, "piece" + i) - SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } - } - /** Used by the JVM when serializing this object. */ private def writeObject(out: ObjectOutputStream) { assertValid() @@ -109,99 +155,30 @@ private[spark] class TorrentBroadcast[T: ClassTag]( private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(broadcastId) match { + SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => - value_ = x.asInstanceOf[T] + _value = x.asInstanceOf[T] case None => - val start = System.nanoTime logInfo("Started reading broadcast variable " + id) - - // Initialize @transient variables that will receive garbage values from the master. - resetWorkerVariables() - - if (receiveBroadcast()) { - value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - - /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. - * This creates a trade-off between memory usage and latency. Storing copy doubles - * the memory footprint; not storing doubles deserialization cost. Also, - * this does not need to be reported to BlockManagerMaster since other executors - * does not need to access this block (they only need to fetch the chunks, - * which are reported). - */ - SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - - // Remove arrayOfBlocks from memory once value_ is on local cache - resetWorkerVariables() - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 + val start = System.nanoTime() + val blocks = readBlocks() + val time = (System.nanoTime() - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def resetWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - } - - private def receiveBroadcast(): Boolean = { - // Receive meta-info about the size of broadcast data, - // the number of chunks it is divided into, etc. - val metaId = BroadcastBlockId(id, "meta") - var attemptId = 10 - while (attemptId > 0 && totalBlocks == -1) { - SparkEnv.get.blockManager.getSingle(metaId) match { - case Some(x) => - val tInfo = x.asInstanceOf[TorrentInfo] - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - arrayOfBlocks = new Array[TorrentBlock](totalBlocks) - hasBlocks = 0 - case None => - Thread.sleep(500) - } - attemptId -= 1 - } - - if (totalBlocks == -1) { - return false - } - - /* - * Fetch actual chunks of data. Note that all these chunks are stored in - * the BlockManager and reported to the master, so that other executors - * can find out and pull the chunks from this executor. - */ - val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) - for (pid <- recvOrder) { - val pieceId = BroadcastBlockId(id, "piece" + pid) - SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => - arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] - hasBlocks += 1 + _value = TorrentBroadcast.unBlockifyObject[T](blocks) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } } - - hasBlocks == totalBlocks } - } -private[broadcast] object TorrentBroadcast extends Logging { + +private object TorrentBroadcast extends Logging { + /** Size of each block. Default value is 4MB. */ private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null @@ -223,7 +200,9 @@ private[broadcast] object TorrentBroadcast extends Logging { initialized = false } - def blockifyObject[T: ClassTag](obj: T): TorrentInfo = { + def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = { + // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks + // so we don't need to do the extra memory copy. val bos = new ByteArrayOutputStream() val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos val ser = SparkEnv.get.serializer.newInstance() @@ -231,44 +210,27 @@ private[broadcast] object TorrentBroadcast extends Logging { serOut.writeObject[T](obj).close() val byteArray = bos.toByteArray val bais = new ByteArrayInputStream(byteArray) + val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt + val blocks = new Array[ByteBuffer](numBlocks) - var blockNum = byteArray.length / BLOCK_SIZE - if (byteArray.length % BLOCK_SIZE != 0) { - blockNum += 1 - } - - val blocks = new Array[TorrentBlock](blockNum) var blockId = 0 - for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) val tempByteArray = new Array[Byte](thisBlockSize) bais.read(tempByteArray, 0, thisBlockSize) - blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blocks(blockId) = ByteBuffer.wrap(tempByteArray) blockId += 1 } bais.close() - - val info = TorrentInfo(blocks, blockNum, byteArray.length) - info.hasBlocks = blockNum - info + blocks } - def unBlockifyObject[T: ClassTag]( - arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { - val retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) - } + def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = { + val is = new SequenceInputStream( + asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is - val in: InputStream = { - val arrIn = new ByteArrayInputStream(retByteArray) - if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn - } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) val obj = serIn.readObject[T]() @@ -284,17 +246,3 @@ private[broadcast] object TorrentBroadcast extends Logging { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } - -private[broadcast] case class TorrentBlock( - blockID: Int, - byteArray: Array[Byte]) - extends Serializable - -private[broadcast] case class TorrentInfo( - @transient arrayOfBlocks: Array[TorrentBlock], - totalBlocks: Int, - totalBytes: Int) - extends Serializable { - - @transient var hasBlocks = 0 -} diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 17c64455b2429..978a6ded80829 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.broadcast -import org.apache.spark.storage.{BroadcastBlockId, _} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.scalatest.FunSuite +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} +import org.apache.spark.storage._ + + class BroadcastSuite extends FunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") @@ -124,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { val numSlaves = if (distributed) 2 else 0 - def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) - // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === 1) statuses.head match { case (bm, status) => assert(bm.executorId === "", "Block should only be on the driver") @@ -139,14 +139,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } if (distributed) { // this file is only generated in distributed mode - assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!") } } // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === numSlaves + 1) statuses.foreach { case (_, status) => assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) @@ -157,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. In the latter case, also verify that the broadcast file is deleted on the driver. - def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) val expectedNumBlocks = if (removeFromDriver) 0 else 1 val possiblyNot = if (removeFromDriver) "" else " not" assert(statuses.size === expectedNumBlocks, "Block should%s be unpersisted on the driver".format(possiblyNot)) if (distributed && removeFromDriver) { // this file is only generated in distributed mode - assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + assert(!HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file should%s be deleted".format(possiblyNot)) } } - testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -185,67 +185,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) { val numSlaves = if (distributed) 2 else 0 - def getBlockIds(id: Long) = { - val broadcastBlockId = BroadcastBlockId(id) - val metaBlockId = BroadcastBlockId(id, "meta") - // Assume broadcast value is small enough to fit into 1 piece - val pieceBlockId = BroadcastBlockId(id, "piece0") - if (distributed) { - // the metadata and piece blocks are generated only in distributed mode - Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) - } else { - Seq[BroadcastBlockId](broadcastBlockId) - } + // Verify that blocks are persisted only on the driver + def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === 1) + + blockId = BroadcastBlockId(broadcastId, "piece0") + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === (if (distributed) 1 else 0)) } - // Verify that blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (distributed) { + assert(statuses.size === numSlaves + 1) + } else { assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.executorId === "", "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } } - } - // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - if (blockId.field == "meta") { - // Meta data is only on the driver - assert(statuses.size === 1) - statuses.head match { case (bm, _) => assert(bm.executorId === "") } - } else { - // Other blocks are on both the executors and the driver - assert(statuses.size === numSlaves + 1, - blockId + " has " + statuses.size + " statuses: " + statuses.mkString(",")) - statuses.foreach { case (_, status) => - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store") - assert(status.diskSize === 0, "Block should not be in disk store") - } - } + blockId = BroadcastBlockId(broadcastId, "piece0") + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (distributed) { + assert(statuses.size === numSlaves + 1) + } else { + assert(statuses.size === 0) } } // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. - def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - val expectedNumBlocks = if (removeFromDriver) 0 else 1 - val possiblyNot = if (removeFromDriver) "" else " not" - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - } + def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var expectedNumBlocks = if (removeFromDriver) 0 else 1 + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks) + + blockId = BroadcastBlockId(broadcastId, "piece0") + expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1 + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -262,10 +246,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { distributed: Boolean, numSlaves: Int, // used only when distributed = true broadcastConf: SparkConf, - getBlockIds: Long => Seq[BroadcastBlockId], - afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, - afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, - afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterCreation: (Long, BlockManagerMaster) => Unit, + afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, + afterUnpersist: (Long, BlockManagerMaster) => Unit, removeFromDriver: Boolean) { sc = if (distributed) { @@ -278,15 +261,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Create broadcast variable val broadcast = sc.broadcast(list) - val blocks = getBlockIds(broadcast.id) - afterCreation(blocks, blockManagerMaster) + afterCreation(broadcast.id, blockManagerMaster) // Use broadcast variable on all executors val partitions = 10 assert(partitions > numSlaves) val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) - afterUsingBroadcast(blocks, blockManagerMaster) + afterUsingBroadcast(broadcast.id, blockManagerMaster) // Unpersist broadcast if (removeFromDriver) { @@ -294,7 +276,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } else { broadcast.unpersist(blocking = true) } - afterUnpersist(blocks, blockManagerMaster) + afterUnpersist(broadcast.id, blockManagerMaster) // If the broadcast is removed from driver, all subsequent uses of the broadcast variable // should throw SparkExceptions. Otherwise, the result should be the same as before. From a5bc9c601e9093b3b896563d23bb2e4add1f0676 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 19 Aug 2014 22:16:22 -0700 Subject: [PATCH 204/231] [SPARK-3142][MLLIB] output shuffle data directly in Word2Vec Sorry I didn't realize this in #2043. Ishiihara Author: Xiangrui Meng Closes #2049 from mengxr/more-w2v and squashes the following commits: 050b1c5 [Xiangrui Meng] output shuffle data directly (cherry picked from commit 0a984aa155fb7f532fe87620dcf1a2814c5b8b49) Signed-off-by: Xiangrui Meng --- .../apache/spark/mllib/feature/Word2Vec.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index c3375ed44fd99..fc1444705364a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging { } val syn0Local = model._1 val syn1Local = model._2 - val synOut = mutable.ListBuffer.empty[(Int, Array[Float])] - var index = 0 - while(index < vocabSize) { - if (syn0Modify(index) != 0) { - synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) + // Only output modified vectors. + Iterator.tabulate(vocabSize) { index => + if (syn0Modify(index) > 0) { + Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None } - if (syn1Modify(index) != 0) { - synOut += ((index + vocabSize, - syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) + }.flatten ++ Iterator.tabulate(vocabSize) { index => + if (syn1Modify(index) > 0) { + Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None } - index += 1 - } - synOut.toIterator + }.flatten } val synAgg = partial.reduceByKey { case (v1, v2) => blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) From 5d1a8786686705ae494f60a47c3a9c2e0ce8ff14 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 19 Aug 2014 22:42:50 -0700 Subject: [PATCH 205/231] [SPARK-2974] [SPARK-2975] Fix two bugs related to spark.local.dirs This PR fixes two bugs related to `spark.local.dirs` and `SPARK_LOCAL_DIRS`, one where `Utils.getLocalDir()` might return an invalid directory (SPARK-2974) and another where the `SPARK_LOCAL_DIRS` override didn't affect the driver, which could cause problems when running tasks in local mode (SPARK-2975). This patch fixes both issues: the new `Utils.getOrCreateLocalRootDirs(conf: SparkConf)` utility method manages the creation of local directories and handles the precedence among the different configuration options, so we should see the same behavior whether we're running in local mode or on a worker. It's kind of a pain to mock out environment variables in tests (no easy way to mock System.getenv), so I added a `private[spark]` method to SparkConf for accessing environment variables (by default, it just delegates to System.getenv). By subclassing SparkConf and overriding this method, we can mock out SPARK_LOCAL_DIRS in tests. I also fixed a typo in PySpark where we used `SPARK_LOCAL_DIR` instead of `SPARK_LOCAL_DIRS` (I think this was technically innocuous, but it seemed worth fixing). Author: Josh Rosen Closes #2002 from JoshRosen/local-dirs and squashes the following commits: efad8c6 [Josh Rosen] Address review comments: 1dec709 [Josh Rosen] Minor updates to Javadocs. 7f36999 [Josh Rosen] Use env vars to detect if running in YARN container. 399ac25 [Josh Rosen] Update getLocalDir() documentation. bb3ad89 [Josh Rosen] Remove duplicated YARN getLocalDirs() code. 3e92d44 [Josh Rosen] Move local dirs override logic into Utils; fix bugs: b2c4736 [Josh Rosen] Add failing tests for SPARK-2974 and SPARK-2975. 007298b [Josh Rosen] Allow environment variables to be mocked in tests. 6d9259b [Josh Rosen] Fix typo in PySpark: SPARK_LOCAL_DIR should be SPARK_LOCAL_DIRS (cherry picked from commit ebcb94f701273b56851dade677e047388a8bca09) Signed-off-by: Patrick Wendell --- .../scala/org/apache/spark/SparkConf.scala | 8 ++- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../org/apache/spark/executor/Executor.scala | 25 ------- .../apache/spark/storage/BlockManager.scala | 3 +- .../spark/storage/DiskBlockManager.scala | 14 ++-- .../scala/org/apache/spark/util/Utils.scala | 67 +++++++++++++++++-- .../spark/storage/BlockManagerSuite.scala | 3 +- .../spark/storage/DiskBlockManagerSuite.scala | 4 +- .../apache/spark/storage/LocalDirsSuite.scala | 61 +++++++++++++++++ python/pyspark/shuffle.py | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 18 ----- .../spark/deploy/yarn/ExecutorLauncher.scala | 19 ------ .../spark/deploy/yarn/ApplicationMaster.scala | 18 ----- .../spark/deploy/yarn/ExecutorLauncher.scala | 19 ------ 14 files changed, 145 insertions(+), 118 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b4f321ec99e78..605df0e929faa 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -45,7 +45,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private val settings = new HashMap[String, String]() + private[spark] val settings = new HashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties @@ -210,6 +210,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { new SparkConf(false).setAll(settings) } + /** + * By using this instead of System.getenv(), environment variables can be mocked + * in unit tests. + */ + private[spark] def getenv(name: String): String = System.getenv(name) + /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 10210a2927dcc..747023812f754 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -62,7 +62,7 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread + envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index fb3f7bd54bbfa..2f76e532aeb76 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -62,16 +62,6 @@ private[spark] class Executor( val conf = new SparkConf(true) conf.setAll(properties) - // If we are in yarn mode, systems can have different disk layouts so we must set it - // to what Yarn on this system said was available. This will be used later when SparkEnv - // created. - if (java.lang.Boolean.valueOf( - System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))) { - conf.set("spark.local.dir", getYarnLocalDirs()) - } else if (sys.env.contains("SPARK_LOCAL_DIRS")) { - conf.set("spark.local.dir", sys.env("SPARK_LOCAL_DIRS")) - } - if (!isLocal) { // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire @@ -134,21 +124,6 @@ private[spark] class Executor( threadPool.shutdown() } - /** Get the Yarn approved local directories. */ - private def getYarnLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty) { - throw new Exception("Yarn Local dirs can't be empty") - } - localDirs - } - class TaskRunner( execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c0491fb55e3a4..12a92d44f4c36 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -67,8 +67,7 @@ private[spark] class BlockManager( private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) - val diskBlockManager = new DiskBlockManager(shuffleBlockManager, - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) val connectionManager = new ConnectionManager(port, conf, securityManager, "Connection manager for block manager") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f3da816389581..ec022ce9c048a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -21,7 +21,7 @@ import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} -import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.network.netty.PathResolver import org.apache.spark.util.Utils @@ -33,9 +33,10 @@ import org.apache.spark.shuffle.sort.SortShuffleManager * However, it is also possible to have a block map to only a segment of a file, by calling * mapBlockToFileSegment(). * - * @param rootDirs The directories to use for storing block files. Data will be hashed among these. + * Block files are hashed among the directories listed in spark.local.dir (or in + * SPARK_LOCAL_DIRS, if it's set). */ -private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String) +private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, conf: SparkConf) extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @@ -46,7 +47,7 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at the top level. */ - val localDirs: Array[File] = createLocalDirs() + val localDirs: Array[File] = createLocalDirs(conf) if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) @@ -130,10 +131,9 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, (blockId, getFile(blockId)) } - private def createLocalDirs(): Array[File] = { - logDebug(s"Creating local directories at root dirs '$rootDirs'") + private def createLocalDirs(conf: SparkConf): Array[File] = { val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").flatMap { rootDir => + Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => var foundLocalDir = false var localDir: File = null var localDirId: String = null diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 019f68b160894..d6d74ce269219 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -449,12 +449,71 @@ private[spark] object Utils extends Logging { } /** - * Get a temporary directory using Spark's spark.local.dir property, if set. This will always - * return a single directory, even though the spark.local.dir property might be a list of - * multiple paths. + * Get the path of a temporary directory. Spark's local directories can be configured through + * multiple settings, which are used with the following precedence: + * + * - If called from inside of a YARN container, this will return a directory chosen by YARN. + * - If the SPARK_LOCAL_DIRS environment variable is set, this will return a directory from it. + * - Otherwise, if the spark.local.dir is set, this will return a directory from it. + * - Otherwise, this will return java.io.tmpdir. + * + * Some of these configuration options might be lists of multiple paths, but this method will + * always return a single directory. */ def getLocalDir(conf: SparkConf): String = { - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + getOrCreateLocalRootDirs(conf)(0) + } + + private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { + // These environment variables are set by YARN. + // For Hadoop 0.23.X, we check for YARN_LOCAL_DIRS (we use this below in getYarnLocalDirs()) + // For Hadoop 2.X, we check for CONTAINER_ID. + conf.getenv("CONTAINER_ID") != null || conf.getenv("YARN_LOCAL_DIRS") != null + } + + /** + * Gets or creates the directories listed in spark.local.dir or SPARK_LOCAL_DIRS, + * and returns only the directories that exist / could be created. + * + * If no directories could be created, this will return an empty list. + */ + private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = { + val confValue = if (isRunningInYarnContainer(conf)) { + // If we are in yarn mode, systems can have different disk layouts so we must set it + // to what Yarn on this system said was available. + getYarnLocalDirs(conf) + } else { + Option(conf.getenv("SPARK_LOCAL_DIRS")).getOrElse( + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + } + val rootDirs = confValue.split(',') + logDebug(s"Getting/creating local root dirs at '$confValue'") + + rootDirs.flatMap { rootDir => + val localDir: File = new File(rootDir) + val foundLocalDir = localDir.exists || localDir.mkdirs() + if (!foundLocalDir) { + logError(s"Failed to create local root dir in $rootDir. Ignoring this directory.") + None + } else { + Some(rootDir) + } + } + } + + /** Get the Yarn approved local directories. */ + private def getYarnLocalDirs(conf: SparkConf): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so lets check both. We assume one of the 2 is set. + // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS")) + .getOrElse(Option(conf.getenv("LOCAL_DIRS")) + .getOrElse("")) + + if (localDirs.isEmpty) { + throw new Exception("Yarn Local dirs can't be empty") + } + localDirs } /** diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 20bac66105a69..f32ce6f9fcc7f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -825,8 +825,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter val blockManager = mock(classOf[BlockManager]) val shuffleBlockManager = mock(classOf[ShuffleBlockManager]) when(shuffleBlockManager.conf).thenReturn(conf) - val diskBlockManager = new DiskBlockManager(shuffleBlockManager, - System.getProperty("java.io.tmpdir")) + val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 777579bc570db..aabaeadd7a071 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -71,7 +71,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before } override def beforeEach() { - diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs) + val conf = testConf.clone + conf.set("spark.local.dir", rootDirs) + diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) shuffleBlockManager.idToSegmentMap.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala new file mode 100644 index 0000000000000..dae7bf0e336de --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.File + +import org.apache.spark.util.Utils +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf + + +/** + * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. + */ +class LocalDirsSuite extends FunSuite { + + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { + // Regression test for SPARK-2974 + assert(!new File("/NONEXISTENT_DIR").exists()) + val conf = new SparkConf(false) + .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") + assert(new File(Utils.getLocalDir(conf)).exists()) + } + + test("SPARK_LOCAL_DIRS override also affects driver") { + // Regression test for SPARK-2975 + assert(!new File("/NONEXISTENT_DIR").exists()) + // SPARK_LOCAL_DIRS is a valid directory: + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + // spark.local.dir only contains invalid directories, but that's not a problem since + // SPARK_LOCAL_DIRS will override it on both the driver and workers: + val conf = new MySparkConf().set("spark.local.dir", "/NONEXISTENT_PATH") + assert(new File(Utils.getLocalDir(conf)).exists()) + } + +} diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 2c68cd4921deb..1ebe7df418327 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -214,7 +214,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, def _get_dirs(self): """ Get all the directories """ - path = os.environ.get("SPARK_LOCAL_DIR", "/tmp") + path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") dirs = path.split(",") return [os.path.join(d, "python", str(os.getpid()), str(id(self))) for d in dirs] diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 46a01f5a9a2cc..4d4848b1bd8f8 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -72,10 +72,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var registered = false def run() { - // Setup the directories so things go to yarn approved directories rather - // then user specified and /tmp. - System.setProperty("spark.local.dir", getLocalDirs()) - // set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -138,20 +134,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, params) } - /** Get the Yarn approved local directories. */ - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") - case Some(l) => l - } - } - private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 72c7143edcd71..c3310fbc24a98 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -95,11 +95,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } def run() { - - // Setup the directories so things go to yarn approved directories rather - // then user specified and /tmp. - System.setProperty("spark.local.dir", getLocalDirs()) - appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() @@ -152,20 +147,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp System.exit(0) } - /** Get the Yarn approved local directories. */ - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") - case Some(l) => l - } - } - private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9c2bcf17a8508..1c4005fd8e78e 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -72,10 +72,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var registered = false def run() { - // Setup the directories so things go to YARN approved directories rather - // than user specified and /tmp. - System.setProperty("spark.local.dir", getLocalDirs()) - // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -144,20 +140,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, "spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) } - // Get the Yarn approved local directories. - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn local dirs can't be empty") - case Some(l) => l - } - } - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index a7585748b7f88..45925f1fea005 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -94,11 +94,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } def run() { - - // Setup the directories so things go to yarn approved directories rather - // then user specified and /tmp. - System.setProperty("spark.local.dir", getLocalDirs()) - amClient = AMRMClient.createAMRMClient() amClient.init(yarnConf) amClient.start() @@ -141,20 +136,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp System.exit(0) } - /** Get the Yarn approved local directories. */ - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") - case Some(l) => l - } - } - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") From f8c908ebfebb4b7a09dec6c806732997a73c1b84 Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Tue, 19 Aug 2014 22:43:22 -0700 Subject: [PATCH 206/231] [DOCS] Fixed wrong links Author: Ken Takagiwa Closes #2042 from giwa/patch-1 and squashes the following commits: 216fe0e [Ken Takagiwa] Fixed wrong links (cherry picked from commit 8a74e4b2a8c7dab154b406539487cf29d578d208) Signed-off-by: Reynold Xin --- docs/streaming-custom-receivers.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 1e045a3dd0ca9..27cd085782f66 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -186,7 +186,7 @@ JavaDStream words = lines.flatMap(new FlatMapFunction() ... {% endhighlight %} -The full source code is in the example [JavaCustomReceiver.java](https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/streaming/examples/JavaCustomReceiver.java). +The full source code is in the example [JavaCustomReceiver.java](https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java).
    @@ -215,7 +215,7 @@ And a new input stream can be created with this custom actor as val lines = ssc.actorStream[String](Props(new CustomActor()), "CustomReceiver") {% endhighlight %} -See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala) +See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. From 5b22ebf68bdf7ac537999abb0e7d18c18ad8d0b0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 19 Aug 2014 22:43:49 -0700 Subject: [PATCH 207/231] [SPARK-3141] [PySpark] fix sortByKey() with take() Fix sortByKey() with take() The function `f` used in mapPartitions should always return an iterator. Author: Davies Liu Closes #2045 from davies/fix_sortbykey and squashes the following commits: 1160f59 [Davies Liu] fix sortByKey() with take() (cherry picked from commit 0a7ef6339f18e68d703599aff7db2dd9c2003866) Signed-off-by: Patrick Wendell --- python/pyspark/rdd.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 140cbe05a43b0..3eefc878d274e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -575,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortByKey().first() + ('1', 3) >>> sc.parallelize(tmp).sortByKey(True, 1).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() @@ -587,14 +589,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() + def sortPartition(iterator): + return iter(sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=not ascending)) + if numPartitions == 1: if self.getNumPartitions() > 1: self = self.coalesce(1) - - def sort(iterator): - return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - - return self.mapPartitions(sort) + return self.mapPartitions(sortPartition) # first compute the boundary of each part via sampling: we want to partition # the key-space into bins such that the bins have roughly the same @@ -610,17 +611,14 @@ def sort(iterator): bounds = [samples[len(samples) * (i + 1) / numPartitions] for i in range(0, numPartitions - 1)] - def rangePartitionFunc(k): + def rangePartitioner(k): p = bisect.bisect_left(bounds, keyfunc(k)) if ascending: return p else: return numPartitions - 1 - p - def mapFunc(iterator): - return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - - return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True) + return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True) def sortBy(self, keyfunc, ascending=True, numPartitions=None): """ From 9b29099557596356c2ae6baa82afc899c8a557f2 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Wed, 20 Aug 2014 04:09:54 -0700 Subject: [PATCH 208/231] [SPARK-3054][STREAMING] Add unit tests for Spark Sink. This patch adds unit tests for Spark Sink. It also removes the private[flume] for Spark Sink, since the sink is instantiated from Flume configuration (looks like this is ignored by reflection which is used by Flume, but we should still remove it anyway). Author: Hari Shreedharan Author: Hari Shreedharan Closes #1958 from harishreedharan/spark-sink-test and squashes the following commits: e3110b9 [Hari Shreedharan] Add a sleep to allow sink to commit the transactions 120b81e [Hari Shreedharan] Fix complexity in threading model in test 4df5be6 [Hari Shreedharan] Merge remote-tracking branch 'asf/master' into spark-sink-test c9190d1 [Hari Shreedharan] Indentation and spaces changes 7fedc5a [Hari Shreedharan] Merge remote-tracking branch 'asf/master' into spark-sink-test abc20cb [Hari Shreedharan] Minor test changes 7b9b649 [Hari Shreedharan] Merge branch 'master' into spark-sink-test f2c56c9 [Hari Shreedharan] Update SparkSinkSuite.scala a24aac8 [Hari Shreedharan] Remove unused var c86d615 [Hari Shreedharan] [SPARK-3054][STREAMING] Add unit tests for Spark Sink. (cherry picked from commit 8c5a2226932c572898c76eb6fab9283f02ad4103) Signed-off-by: Tathagata Das --- external/flume-sink/pom.xml | 7 + .../streaming/flume/sink/SparkSink.scala | 1 - .../streaming/flume/sink/SparkSinkSuite.scala | 204 ++++++++++++++++++ .../flume/FlumePollingStreamSuite.scala | 2 +- 4 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index cfbf943bdafe0..7f1172ec2092d 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -72,6 +72,13 @@ org.scalatest scalatest_${scala.binary.version} + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + target/scala-${scala.binary.version}/classes diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 7b735133e3d14..1a61b36910a95 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -53,7 +53,6 @@ import org.apache.flume.sink.AbstractSink * */ -private[flume] class SparkSink extends AbstractSink with Logging with Configurable { // Size of the pool to use for holding transaction processors. diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala new file mode 100644 index 0000000000000..44b27edf85ce8 --- /dev/null +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} + +import scala.collection.JavaConversions._ +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.event.EventBuilder +import org.apache.spark.streaming.TestSuiteBase +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +class SparkSinkSuite extends TestSuiteBase { + val eventsPerBatch = 1000 + val channelCapacity = 5000 + + test("Success") { + val (channel, sink) = initializeChannelAndSink() + channel.start() + sink.start() + + putEvents(channel, eventsPerBatch) + + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + + val (transceiver, client) = getTransceiverAndClient(address, 1)(0) + val events = client.getEventBatch(1000) + client.ack(events.getSequenceNumber) + assert(events.getEvents.size() === 1000) + assertChannelIsEmpty(channel) + sink.stop() + channel.stop() + transceiver.close() + } + + test("Nack") { + val (channel, sink) = initializeChannelAndSink() + channel.start() + sink.start() + putEvents(channel, eventsPerBatch) + + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + + val (transceiver, client) = getTransceiverAndClient(address, 1)(0) + val events = client.getEventBatch(1000) + assert(events.getEvents.size() === 1000) + client.nack(events.getSequenceNumber) + assert(availableChannelSlots(channel) === 4000) + sink.stop() + channel.stop() + transceiver.close() + } + + test("Timeout") { + val (channel, sink) = initializeChannelAndSink(Map(SparkSinkConfig + .CONF_TRANSACTION_TIMEOUT -> 1.toString)) + channel.start() + sink.start() + putEvents(channel, eventsPerBatch) + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + + val (transceiver, client) = getTransceiverAndClient(address, 1)(0) + val events = client.getEventBatch(1000) + assert(events.getEvents.size() === 1000) + Thread.sleep(1000) + assert(availableChannelSlots(channel) === 4000) + sink.stop() + channel.stop() + transceiver.close() + } + + test("Multiple consumers") { + testMultipleConsumers(failSome = false) + } + + test("Multiple consumers with some failures") { + testMultipleConsumers(failSome = true) + } + + def testMultipleConsumers(failSome: Boolean): Unit = { + implicit val executorContext = ExecutionContext + .fromExecutorService(Executors.newFixedThreadPool(5)) + val (channel, sink) = initializeChannelAndSink() + channel.start() + sink.start() + (1 to 5).foreach(_ => putEvents(channel, eventsPerBatch)) + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + val transceiversAndClients = getTransceiverAndClient(address, 5) + val batchCounter = new CountDownLatch(5) + val counter = new AtomicInteger(0) + transceiversAndClients.foreach(x => { + Future { + val client = x._2 + val events = client.getEventBatch(1000) + if (!failSome || counter.getAndIncrement() % 2 == 0) { + client.ack(events.getSequenceNumber) + } else { + client.nack(events.getSequenceNumber) + throw new RuntimeException("Sending NACK for failure!") + } + events + }.onComplete { + case Success(events) => + assert(events.getEvents.size() === 1000) + batchCounter.countDown() + case Failure(t) => + // Don't re-throw the exception, causes a nasty unnecessary stack trace on stdout + batchCounter.countDown() + } + }) + batchCounter.await() + TimeUnit.SECONDS.sleep(1) // Allow the sink to commit the transactions. + executorContext.shutdown() + if(failSome) { + assert(availableChannelSlots(channel) === 3000) + } else { + assertChannelIsEmpty(channel) + } + sink.stop() + channel.stop() + transceiversAndClients.foreach(x => x._1.close()) + } + + private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty): (MemoryChannel, + SparkSink) = { + val channel = new MemoryChannel() + val channelContext = new Context() + + channelContext.put("capacity", channelCapacity.toString) + channelContext.put("transactionCapacity", 1000.toString) + channelContext.put("keep-alive", 0.toString) + channelContext.putAll(overrides) + channel.configure(channelContext) + + val sink = new SparkSink() + val sinkContext = new Context() + sinkContext.put(SparkSinkConfig.CONF_HOSTNAME, "0.0.0.0") + sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString) + sink.configure(sinkContext) + sink.setChannel(channel) + (channel, sink) + } + + private def putEvents(ch: MemoryChannel, count: Int): Unit = { + val tx = ch.getTransaction + tx.begin() + (1 to count).foreach(x => ch.put(EventBuilder.withBody(x.toString.getBytes))) + tx.commit() + tx.close() + } + + private def getTransceiverAndClient(address: InetSocketAddress, + count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { + + (1 to count).map(_ => { + lazy val channelFactoryExecutor = + Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). + setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactory = + new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) + val transceiver = new NettyTransceiver(address, channelFactory) + val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) + (transceiver, client) + }) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + assert(availableChannelSlots(channel) === channelCapacity) + } + + private def availableChannelSlots(channel: MemoryChannel): Int = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 27bf2ac962721..2e4ac7cfbf263 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -171,7 +171,7 @@ class FlumePollingStreamSuite extends TestSuiteBase { } def assertChannelIsEmpty(channel: MemoryChannel) = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining"); + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) From ca7322dda10def28b1133876aa9196f555c5025e Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 20 Aug 2014 12:13:31 -0700 Subject: [PATCH 209/231] SPARK-3092 [SQL]: Always include the thriftserver when -Phive is enabled. Currently we have a separate profile called hive-thriftserver. I originally suggested this in case users did not want to bundle the thriftserver, but it's ultimately lead to a lot of confusion. Since the thriftserver is only a few classes, I don't see a really good reason to isolate it from the rest of Hive. So let's go ahead and just include it in the same profile to simplify things. This has been suggested in the past by liancheng. Author: Patrick Wendell Closes #2006 from pwendell/hiveserver and squashes the following commits: 742ea40 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into hiveserver 034ad47 [Patrick Wendell] SPARK-3092: Always include the thriftserver when -Phive is enabled. (cherry picked from commit f2f26c2a1dc6d60078c3be9c3d11a21866d9a24f) Signed-off-by: Patrick Wendell --- README.md | 6 +----- assembly/pom.xml | 5 ----- dev/create-release/create-release.sh | 10 +++++----- dev/run-tests | 2 +- dev/scalastyle | 2 +- docs/building-with-maven.md | 8 ++------ docs/sql-programming-guide.md | 4 +--- pom.xml | 2 +- 8 files changed, 12 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index a1a48f5bd0819..8906e4c1416b1 100644 --- a/README.md +++ b/README.md @@ -118,11 +118,7 @@ If your project is built with Maven, add this to your POM file's ` ## A Note About Thrift JDBC server and CLI for Spark SQL Spark SQL supports Thrift JDBC server and CLI. -See sql-programming-guide.md for more information about those features. -You can use those features by setting `-Phive-thriftserver` when building Spark as follows. - - $ sbt/sbt -Phive-thriftserver assembly - +See sql-programming-guide.md for more information about using the JDBC server. ## Configuration diff --git a/assembly/pom.xml b/assembly/pom.xml index 16e5271b35050..4709b7dbddfea 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -163,11 +163,6 @@ spark-hive_${scala.binary.version} ${project.version} - - - - hive-thriftserver - org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 28f26d2368254..905dec0ced383 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -60,14 +60,14 @@ if [[ ! "$@" =~ --package-only ]]; then -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ --batch-mode release:prepare mvn -DskipTests \ -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ release:perform cd .. @@ -117,10 +117,10 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & +make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & make_binary_release "hadoop2" \ - "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & + "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & make_binary_release "hadoop2-without-hive" \ "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & wait diff --git a/dev/run-tests b/dev/run-tests index 132f696d6447a..20a67cfb361b9 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -99,7 +99,7 @@ echo -e "q\n" | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled: if [ -n "$_RUN_SQL_TESTS" ]; then - SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" fi # echo "q" is needed because sbt on encountering a build file with failure # (either resolution or compilation) prompts the user for input either q, r, diff --git a/dev/scalastyle b/dev/scalastyle index b53053a04ff42..eb9b467965636 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 4d87ab92cec5b..a7d7bd3ccb1f2 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -98,12 +98,8 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski # Building Thrift JDBC server and CLI for Spark SQL -Spark SQL supports Thrift JDBC server and CLI. -See sql-programming-guide.md for more information about those features. -You can use those features by setting `-Phive-thriftserver` when building Spark as follows. -{% highlight bash %} -mvn -Phive-thriftserver assembly -{% endhighlight %} +Spark SQL supports Thrift JDBC server and CLI. See sql-programming-guide.md for +more information about the JDBC server. # Spark Tests in Maven diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 34accade36ea9..c41f2804a6021 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -578,9 +578,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c The Thrift JDBC server implemented here corresponds to the [`HiveServer2`] (https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test -the JDBC server with the beeline script comes with either Spark or Hive 0.12. In order to use Hive -you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use `-Phive-thriftserver` -for maven). +the JDBC server with the beeline script comes with either Spark or Hive 0.12. To start the JDBC server, run the following in the Spark directory: diff --git a/pom.xml b/pom.xml index 8c4c4af0eda8e..1479326af0ed9 100644 --- a/pom.xml +++ b/pom.xml @@ -1178,7 +1178,7 @@ - hive-thriftserver + hive false From 99ca704aba34282d97a8d05bc2b283a4b344bff2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 20 Aug 2014 12:57:39 -0700 Subject: [PATCH 210/231] [SPARK-3126][SPARK-3127][SQL] Fixed HiveThriftServer2Suite This PR fixes two issues: 1. Fixes wrongly quoted command line option in `HiveThriftServer2Suite` that makes test cases hang until timeout. 1. Asks `dev/run-test` to run Spark SQL tests when `bin/spark-sql` and/or `sbin/start-thriftserver.sh` are modified. Author: Cheng Lian Closes #2036 from liancheng/fix-thriftserver-test and squashes the following commits: f38c4eb [Cheng Lian] Fixed the same quotation issue in CliSuite 26b82a0 [Cheng Lian] Run SQL tests when dff contains bin/spark-sql and/or sbin/start-thriftserver.sh a87f83d [Cheng Lian] Extended timeout e5aa31a [Cheng Lian] Fixed metastore JDBC URI quotation (cherry picked from commit cf46e725814f575ebb417e80d2571bccc6dac4a7) Signed-off-by: Michael Armbrust --- dev/run-tests | 2 +- .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../thriftserver/HiveThriftServer2Suite.scala | 18 ++++-------------- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 20a67cfb361b9..d751961605dfd 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -55,7 +55,7 @@ JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..* # Partial solution for SPARK-1455. Only run Hive tests if there are sql changes. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master - diffs=`git diff --name-only master | grep "^sql/"` + diffs=`git diff --name-only master | grep "^\(sql/\)\|\(bin/spark-sql\)\|\(sbin/start-thriftserver.sh\)"` if [ -n "$diffs" ]; then echo "Detected changes in SQL. Will run Hive test suite." _RUN_SQL_TESTS=true diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 2bf8cfdcacd22..70bea1ed80fda 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -32,7 +32,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { val commands = s"""../../bin/spark-sql | --master local - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$WAREHOUSE_PATH """.stripMargin.split("\\s+") diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index aedef6ce1f5f2..326b0a7275b34 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -51,9 +51,6 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt port } - // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. - val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean - Class.forName(DRIVER_NAME) override def beforeAll() { launchServer() } @@ -68,8 +65,7 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt val command = s"""../../sbin/start-thriftserver.sh | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$METASTORE_PATH | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$HOST | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$PORT @@ -77,12 +73,10 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt val pb = new ProcessBuilder(command ++ args: _*) val environment = pb.environment() - environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) - environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) process = pb.start() inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) - waitForOutput(inputReader, "ThriftBinaryCLIService listening on") + waitForOutput(inputReader, "ThriftBinaryCLIService listening on", 300000) // Spawn a thread to read the output from the forked process. // Note that this is necessary since in some configurations, log4j could be blocked @@ -91,12 +85,8 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt while (true) { val stdout = readFrom(inputReader) val stderr = readFrom(errorReader) - if (VERBOSE && stdout.length > 0) { - println(stdout) - } - if (VERBOSE && stderr.length > 0) { - println(stderr) - } + print(stdout) + print(stderr) Thread.sleep(50) } } From 5095851fc284f31e7d91d192c88d1bbcf02e1d0e Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 20 Aug 2014 13:26:11 -0700 Subject: [PATCH 211/231] [SPARK-3062] [SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled #1891 was to avoid IOException when EventLogging is enabled. The solution used ShutdownHookManager but it was defined only Hadoop 2.x. Hadoop 1.x don't have ShutdownHookManager so #1891 doesn't compile on Hadoop 1.x Now, I had a compromised solution for both Hadoop 1.x and 2.x. Only for FileLogger, an unique FileSystem object is created. Author: Kousuke Saruta Closes #1970 from sarutak/SPARK-2970 and squashes the following commits: 240c91e [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2970 0e7b45d [Kousuke Saruta] Revert "[SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled" e1262ec [Kousuke Saruta] Modified Filelogger to use unique FileSystem instance --- .../scala/org/apache/spark/util/FileLogger.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 2e8fbf5a91ee7..ad8b79af877d8 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -52,7 +52,20 @@ private[spark] class FileLogger( override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } - private val fileSystem = Utils.getHadoopFileSystem(logDir) + /** + * To avoid effects of FileSystem#close or FileSystem.closeAll called from other modules, + * create unique FileSystem instance only for FileLogger + */ + private val fileSystem = { + val conf = SparkHadoopUtil.get.newConfiguration() + val logUri = new URI(logDir) + val scheme = logUri.getScheme + if (scheme == "hdfs") { + conf.setBoolean("fs.hdfs.impl.disable.cache", true) + } + FileSystem.get(logUri, conf) + } + var fileIndex = 0 // Only used if compression is enabled From 25b01fd6bf85ac303094d9bd1d598983461bbe00 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 20 Aug 2014 14:04:39 -0700 Subject: [PATCH 212/231] [SPARK-3149] Connection establishment information is not enough. Author: Kousuke Saruta Closes #2060 from sarutak/SPARK-3149 and squashes the following commits: 1cc89af [Kousuke Saruta] Modified log message of accepting connection (cherry picked from commit c1ba4cd6b4db22a9325eee50dc40a78593a10de1) Signed-off-by: Josh Rosen --- .../main/scala/org/apache/spark/network/ConnectionManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index b3e951ded6e77..e5e1e72cd912b 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -418,7 +418,7 @@ private[spark] class ConnectionManager( newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + logInfo("Accepted connection from [" + newConnection.remoteAddress + "]") } catch { // might happen in case of issues with registering with selector case e: Exception => logError("Error in accept loop", e) From beb705a4723da728be58a08039fb41fa0ffaa4a3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 20 Aug 2014 15:01:47 -0700 Subject: [PATCH 213/231] [SPARK-2849] Handle driver configs separately in client mode In client deploy mode, the driver is launched from within `SparkSubmit`'s JVM. This means by the time we parse Spark configs from `spark-defaults.conf`, it is already too late to control certain properties of the driver's JVM. We currently ignore these configs in client mode altogether. ``` spark.driver.memory spark.driver.extraJavaOptions spark.driver.extraClassPath spark.driver.extraLibraryPath ``` This PR handles these properties before launching the driver JVM. It achieves this by spawning a separate JVM that runs a new class called `SparkSubmitDriverBootstrapper`, which spawns `SparkSubmit` as a sub-process with the appropriate classpath, library paths, java opts and memory. Author: Andrew Or Closes #1845 from andrewor14/handle-configs-bash and squashes the following commits: bed4bdf [Andrew Or] Change a few comments / messages (minor) 24dba60 [Andrew Or] Merge branch 'master' of github.com:apache/spark into handle-configs-bash 08fd788 [Andrew Or] Warn against external usages of SparkSubmitDriverBootstrapper ff34728 [Andrew Or] Minor comments 51aeb01 [Andrew Or] Filter out JVM memory in Scala rather than Bash (minor) 9a778f6 [Andrew Or] Fix PySpark: actually kill driver on termination d0f20db [Andrew Or] Don't pass empty library paths, classpath, java opts etc. a78cb26 [Andrew Or] Revert a few changes in utils.sh (minor) 9ba37e2 [Andrew Or] Don't barf when the properties file does not exist 8867a09 [Andrew Or] A few more naming things (minor) 19464ad [Andrew Or] SPARK_SUBMIT_JAVA_OPTS -> SPARK_SUBMIT_OPTS d6488f9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into handle-configs-bash 1ea6bbe [Andrew Or] SparkClassLauncher -> SparkSubmitDriverBootstrapper a91ea19 [Andrew Or] Fix precedence of library paths, classpath, java opts and memory 158f813 [Andrew Or] Remove "client mode" boolean argument c84f5c8 [Andrew Or] Remove debug print statement (minor) b71f52b [Andrew Or] Revert a few more changes (minor) 7d94a8d [Andrew Or] Merge branch 'master' of github.com:apache/spark into handle-configs-bash 3a8235d [Andrew Or] Only parse the properties file if special configs exist c37e08d [Andrew Or] Revert a few more changes a396eda [Andrew Or] Nullify my own hard work to simplify bash 0effa1e [Andrew Or] Add code in Scala that handles special configs c886568 [Andrew Or] Fix lines too long + a few comments / style (minor) 7a4190a [Andrew Or] Merge branch 'master' of github.com:apache/spark into handle-configs-bash 7396be2 [Andrew Or] Explicitly comment that multi-line properties are not supported fa11ef8 [Andrew Or] Parse the properties file only if the special configs exist 371cac4 [Andrew Or] Add function prefix (minor) be99eb3 [Andrew Or] Fix tests to not include multi-line configs bd0d468 [Andrew Or] Simplify parsing config file by ignoring multi-line arguments 56ac247 [Andrew Or] Use eval and set to simplify splitting 8d4614c [Andrew Or] Merge branch 'master' of github.com:apache/spark into handle-configs-bash aeb79c7 [Andrew Or] Merge branch 'master' of github.com:apache/spark into handle-configs-bash 2732ac0 [Andrew Or] Integrate BASH tests into dev/run-tests + log error properly 8d26a5c [Andrew Or] Add tests for bash/utils.sh 4ae24c3 [Andrew Or] Fix bug: escape properly in quote_java_property b3c4cd5 [Andrew Or] Fix bug: count the number of quotes instead of detecting presence c2273fc [Andrew Or] Fix typo (minor) e793e5f [Andrew Or] Handle multi-line arguments 5d8f8c4 [Andrew Or] Merge branch 'master' of github.com:apache/spark into submit-driver-extra c7b9926 [Andrew Or] Minor changes to spark-defaults.conf.template a992ae2 [Andrew Or] Escape spark.*.extraJavaOptions correctly aabfc7e [Andrew Or] escape -> split (minor) 45a1eb9 [Andrew Or] Fix bug: escape escaped backslashes and quotes properly... 1cdc6b1 [Andrew Or] Fix bug: escape escaped double quotes properly c854859 [Andrew Or] Add small comment c13a2cb [Andrew Or] Merge branch 'master' of github.com:apache/spark into submit-driver-extra 8e552b7 [Andrew Or] Include an example of spark.*.extraJavaOptions de765c9 [Andrew Or] Print spark-class command properly a4df3c4 [Andrew Or] Move parsing and escaping logic to utils.sh dec2343 [Andrew Or] Only export variables if they exist fa2136e [Andrew Or] Escape Java options + parse java properties files properly ef12f74 [Andrew Or] Minor formatting 4ec22a1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into submit-driver-extra e5cfb46 [Andrew Or] Collapse duplicate code + fix potential whitespace issues 4edcaa8 [Andrew Or] Redirect stdout to stderr for python 130f295 [Andrew Or] Handle spark.driver.memory too 98dd8e3 [Andrew Or] Add warning if properties file does not exist 8843562 [Andrew Or] Fix compilation issues... 75ee6b4 [Andrew Or] Remove accidentally added file 63ed2e9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into submit-driver-extra 0025474 [Andrew Or] Revert SparkSubmit handling of --driver-* options for only cluster mode a2ab1b0 [Andrew Or] Parse spark.driver.extra* in bash 250cb95 [Andrew Or] Do not ignore spark.driver.extra* for client mode (cherry picked from commit b3ec51bfd795772ff96d18228e979a52ebc82ec4) Signed-off-by: Patrick Wendell --- bin/spark-class | 49 ++++-- bin/spark-submit | 28 +++- bin/utils.sh | 0 conf/spark-defaults.conf.template | 10 +- .../apache/spark/api/python/PythonUtils.scala | 25 --- .../api/python/PythonWorkerFactory.scala | 3 +- .../apache/spark/deploy/PythonRunner.scala | 4 +- .../org/apache/spark/deploy/SparkSubmit.scala | 17 +- .../SparkSubmitDriverBootstrapper.scala | 149 ++++++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 21 +++ 10 files changed, 250 insertions(+), 56 deletions(-) mode change 100644 => 100755 bin/utils.sh create mode 100644 core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala diff --git a/bin/spark-class b/bin/spark-class index 3f6beca5becf0..22acf92288b3b 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -17,6 +17,8 @@ # limitations under the License. # +# NOTE: Any changes to this file must be reflected in SparkSubmitDriverBootstrapper.scala! + cygwin=false case "`uname`" in CYGWIN*) cygwin=true;; @@ -39,7 +41,7 @@ fi if [ -n "$SPARK_MEM" ]; then echo -e "Warning: SPARK_MEM is deprecated, please use a more specific config option" 1>&2 - echo -e "(e.g., spark.executor.memory or SPARK_DRIVER_MEMORY)." 1>&2 + echo -e "(e.g., spark.executor.memory or spark.driver.memory)." 1>&2 fi # Use SPARK_MEM or 512m as the default memory, to be overridden by specific options @@ -73,11 +75,17 @@ case "$1" in OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} ;; - # Spark submit uses SPARK_SUBMIT_OPTS and SPARK_JAVA_OPTS - 'org.apache.spark.deploy.SparkSubmit') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS \ - -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH" + # Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + + # SPARK_DRIVER_MEMORY + SPARK_SUBMIT_DRIVER_MEMORY. + 'org.apache.spark.deploy.SparkSubmit') + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS" OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} + if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then + OUR_JAVA_OPTS="$OUR_JAVA_OPTS -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH" + fi + if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then + OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY" + fi ;; *) @@ -101,11 +109,12 @@ fi # Set JAVA_OPTS to be able to load native libraries and to set heap size JAVA_OPTS="-XX:MaxPermSize=128m $OUR_JAVA_OPTS" JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" + # Load extra JAVA_OPTS from conf/java-opts, if it exists if [ -e "$FWDIR/conf/java-opts" ] ; then JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" fi -export JAVA_OPTS + # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! TOOLS_DIR="$FWDIR"/tools @@ -146,10 +155,28 @@ if $cygwin; then fi export CLASSPATH -if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then - echo -n "Spark Command: " 1>&2 - echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" 1>&2 - echo -e "========================================\n" 1>&2 +# In Spark submit client mode, the driver is launched in the same JVM as Spark submit itself. +# Here we must parse the properties file for relevant "spark.driver.*" configs before launching +# the driver JVM itself. Instead of handling this complexity in Bash, we launch a separate JVM +# to prepare the launch environment of this driver JVM. + +if [ -n "$SPARK_SUBMIT_BOOTSTRAP_DRIVER" ]; then + # This is used only if the properties file actually contains these special configs + # Export the environment variables needed by SparkSubmitDriverBootstrapper + export RUNNER + export CLASSPATH + export JAVA_OPTS + export OUR_JAVA_MEM + export SPARK_CLASS=1 + shift # Ignore main class (org.apache.spark.deploy.SparkSubmit) and use our own + exec "$RUNNER" org.apache.spark.deploy.SparkSubmitDriverBootstrapper "$@" +else + # Note: The format of this command is closely echoed in SparkSubmitDriverBootstrapper.scala + if [ -n "$SPARK_PRINT_LAUNCH_COMMAND" ]; then + echo -n "Spark Command: " 1>&2 + echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" 1>&2 + echo -e "========================================\n" 1>&2 + fi + exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" fi -exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 9e7cecedd0325..32c911cd0438b 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -17,14 +17,18 @@ # limitations under the License. # +# NOTE: Any changes in this file must be reflected in SparkClassLauncher.scala! + export SPARK_HOME="$(cd `dirname $0`/..; pwd)" ORIG_ARGS=("$@") while (($#)); do if [ "$1" = "--deploy-mode" ]; then - DEPLOY_MODE=$2 + SPARK_SUBMIT_DEPLOY_MODE=$2 + elif [ "$1" = "--properties-file" ]; then + SPARK_SUBMIT_PROPERTIES_FILE=$2 elif [ "$1" = "--driver-memory" ]; then - DRIVER_MEMORY=$2 + export SPARK_SUBMIT_DRIVER_MEMORY=$2 elif [ "$1" = "--driver-library-path" ]; then export SPARK_SUBMIT_LIBRARY_PATH=$2 elif [ "$1" = "--driver-class-path" ]; then @@ -35,10 +39,24 @@ while (($#)); do shift done -DEPLOY_MODE=${DEPLOY_MODE:-"client"} +DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf" +export SPARK_SUBMIT_DEPLOY_MODE=${SPARK_SUBMIT_DEPLOY_MODE:-"client"} +export SPARK_SUBMIT_PROPERTIES_FILE=${SPARK_SUBMIT_PROPERTIES_FILE:-"$DEFAULT_PROPERTIES_FILE"} + +# For client mode, the driver will be launched in the same JVM that launches +# SparkSubmit, so we may need to read the properties file for any extra class +# paths, library paths, java options and memory early on. Otherwise, it will +# be too late by the time the driver JVM has started. -if [ -n "$DRIVER_MEMORY" ] && [ $DEPLOY_MODE == "client" ]; then - export SPARK_DRIVER_MEMORY=$DRIVER_MEMORY +if [[ "$SPARK_SUBMIT_DEPLOY_MODE" == "client" && -f "$SPARK_SUBMIT_PROPERTIES_FILE" ]]; then + # Parse the properties file only if the special configs exist + contains_special_configs=$( + grep -e "spark.driver.extra*\|spark.driver.memory" "$SPARK_SUBMIT_PROPERTIES_FILE" | \ + grep -v "^[[:space:]]*#" + ) + if [ -n "$contains_special_configs" ]; then + export SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 + fi fi exec $SPARK_HOME/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" diff --git a/bin/utils.sh b/bin/utils.sh old mode 100644 new mode 100755 diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index 2779342769c14..94427029b94d7 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -2,7 +2,9 @@ # This is useful for setting default environmental settings. # Example: -# spark.master spark://master:7077 -# spark.eventLog.enabled true -# spark.eventLog.dir hdfs://namenode:8021/directory -# spark.serializer org.apache.spark.serializer.KryoSerializer +# spark.master spark://master:7077 +# spark.eventLog.enabled true +# spark.eventLog.dir hdfs://namenode:8021/directory +# spark.serializer org.apache.spark.serializer.KryoSerializer +# spark.driver.memory 5g +# spark.executor.extraJavaOptions -XX:+PrintGCDetail -Dkey=value -Dnumbers="one two three" diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 52c70712eea3d..be5ebfa9219d3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -40,28 +40,3 @@ private[spark] object PythonUtils { paths.filter(_ != "").mkString(File.pathSeparator) } } - - -/** - * A utility class to redirect the child process's stdout or stderr. - */ -private[spark] class RedirectThread( - in: InputStream, - out: OutputStream, - name: String) - extends Thread(name) { - - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - out.write(buf, 0, len) - out.flush() - len = in.read(buf) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index bf716a8ab025b..4c4796f6c59ba 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,6 @@ package org.apache.spark.api.python -import java.lang.Runtime import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} @@ -25,7 +24,7 @@ import scala.collection.mutable import scala.collection.JavaConversions._ import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 0d6751f3fa6d2..b66c3ba4d5fb0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -22,8 +22,8 @@ import java.net.URI import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ -import org.apache.spark.api.python.{PythonUtils, RedirectThread} -import org.apache.spark.util.Utils +import org.apache.spark.api.python.PythonUtils +import org.apache.spark.util.{RedirectThread, Utils} /** * A main class used by spark-submit to launch Python applications. It executes python as a diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 318509a67a36f..f8cdbc3c392b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -195,18 +195,21 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options - OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, CLUSTER, - sysProp = "spark.driver.extraClassPath"), - OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, CLUSTER, - sysProp = "spark.driver.extraJavaOptions"), - OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, - sysProp = "spark.driver.extraLibraryPath"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files") + sysProp = "spark.files"), + + // Only process driver specific options for cluster mode here, + // because they have already been processed in bash for client mode + OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, CLUSTER, + sysProp = "spark.driver.extraClassPath"), + OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, CLUSTER, + sysProp = "spark.driver.extraJavaOptions"), + OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, + sysProp = "spark.driver.extraLibraryPath") ) // In client mode, launch the application main class directly diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala new file mode 100644 index 0000000000000..af607e6a4a065 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.File + +import scala.collection.JavaConversions._ + +import org.apache.spark.util.{RedirectThread, Utils} + +/** + * Launch an application through Spark submit in client mode with the appropriate classpath, + * library paths, java options and memory. These properties of the JVM must be set before the + * driver JVM is launched. The sole purpose of this class is to avoid handling the complexity + * of parsing the properties file for such relevant configs in Bash. + * + * Usage: org.apache.spark.deploy.SparkSubmitDriverBootstrapper + */ +private[spark] object SparkSubmitDriverBootstrapper { + + // Note: This class depends on the behavior of `bin/spark-class` and `bin/spark-submit`. + // Any changes made there must be reflected in this file. + + def main(args: Array[String]): Unit = { + + // This should be called only from `bin/spark-class` + if (!sys.env.contains("SPARK_CLASS")) { + System.err.println("SparkSubmitDriverBootstrapper must be called from `bin/spark-class`!") + System.exit(1) + } + + val submitArgs = args + val runner = sys.env("RUNNER") + val classpath = sys.env("CLASSPATH") + val javaOpts = sys.env("JAVA_OPTS") + val defaultDriverMemory = sys.env("OUR_JAVA_MEM") + + // Spark submit specific environment variables + val deployMode = sys.env("SPARK_SUBMIT_DEPLOY_MODE") + val propertiesFile = sys.env("SPARK_SUBMIT_PROPERTIES_FILE") + val bootstrapDriver = sys.env("SPARK_SUBMIT_BOOTSTRAP_DRIVER") + val submitDriverMemory = sys.env.get("SPARK_SUBMIT_DRIVER_MEMORY") + val submitLibraryPath = sys.env.get("SPARK_SUBMIT_LIBRARY_PATH") + val submitClasspath = sys.env.get("SPARK_SUBMIT_CLASSPATH") + val submitJavaOpts = sys.env.get("SPARK_SUBMIT_OPTS") + + assume(runner != null, "RUNNER must be set") + assume(classpath != null, "CLASSPATH must be set") + assume(javaOpts != null, "JAVA_OPTS must be set") + assume(defaultDriverMemory != null, "OUR_JAVA_MEM must be set") + assume(deployMode == "client", "SPARK_SUBMIT_DEPLOY_MODE must be \"client\"!") + assume(propertiesFile != null, "SPARK_SUBMIT_PROPERTIES_FILE must be set") + assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") + + // Parse the properties file for the equivalent spark.driver.* configs + val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap + val confDriverMemory = properties.get("spark.driver.memory") + val confLibraryPath = properties.get("spark.driver.extraLibraryPath") + val confClasspath = properties.get("spark.driver.extraClassPath") + val confJavaOpts = properties.get("spark.driver.extraJavaOptions") + + // Favor Spark submit arguments over the equivalent configs in the properties file. + // Note that we do not actually use the Spark submit values for library path, classpath, + // and Java opts here, because we have already captured them in Bash. + + val newDriverMemory = submitDriverMemory + .orElse(confDriverMemory) + .getOrElse(defaultDriverMemory) + + val newLibraryPath = + if (submitLibraryPath.isDefined) { + // SPARK_SUBMIT_LIBRARY_PATH is already captured in JAVA_OPTS + "" + } else { + confLibraryPath.map("-Djava.library.path=" + _).getOrElse("") + } + + val newClasspath = + if (submitClasspath.isDefined) { + // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH + classpath + } else { + classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("") + } + + val newJavaOpts = + if (submitJavaOpts.isDefined) { + // SPARK_SUBMIT_OPTS is already captured in JAVA_OPTS + javaOpts + } else { + javaOpts + confJavaOpts.map(" " + _).getOrElse("") + } + + val filteredJavaOpts = Utils.splitCommandString(newJavaOpts) + .filterNot(_.startsWith("-Xms")) + .filterNot(_.startsWith("-Xmx")) + + // Build up command + val command: Seq[String] = + Seq(runner) ++ + Seq("-cp", newClasspath) ++ + Seq(newLibraryPath) ++ + filteredJavaOpts ++ + Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++ + Seq("org.apache.spark.deploy.SparkSubmit") ++ + submitArgs + + // Print the launch command. This follows closely the format used in `bin/spark-class`. + if (sys.env.contains("SPARK_PRINT_LAUNCH_COMMAND")) { + System.err.print("Spark Command: ") + System.err.println(command.mkString(" ")) + System.err.println("========================================\n") + } + + // Start the driver JVM + val filteredCommand = command.filter(_.nonEmpty) + val builder = new ProcessBuilder(filteredCommand) + val process = builder.start() + + // Redirect stdin, stdout, and stderr to/from the child JVM + val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin") + val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout") + val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr") + stdinThread.start() + stdoutThread.start() + stderrThread.start() + + // Terminate on broken pipe, which signals that the parent process has exited. This is + // important for the PySpark shell, where Spark submit itself is a python subprocess. + stdinThread.join() + process.destroy() + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d6d74ce269219..69a84a3604a52 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1480,3 +1480,24 @@ private[spark] object Utils extends Logging { } } + +/** + * A utility class to redirect the child process's stdout or stderr. + */ +private[spark] class RedirectThread(in: InputStream, out: OutputStream, name: String) + extends Thread(name) { + + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. + val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + out.write(buf, 0, len) + out.flush() + len = in.read(buf) + } + } + } +} From 311831db71b742a0472d67a1127c818e5ba0a505 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 20 Aug 2014 15:51:14 -0700 Subject: [PATCH 214/231] [SPARK-2967][SQL] Fix sort based shuffle for spark sql. Add explicit row copies when sort based shuffle is on. Author: Michael Armbrust Closes #2066 from marmbrus/sortShuffle and squashes the following commits: fcd7bb2 [Michael Armbrust] Fix sort based shuffle for spark sql. (cherry picked from commit a2e658dcdab614058eefcf50ae2d419ece9b1fe7) Signed-off-by: Michael Armbrust --- .../apache/spark/sql/execution/Exchange.scala | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 77dc2ad733215..09c34b7059fc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree @@ -37,6 +38,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una def output = child.output + /** We must copy rows when sort based shuffle is on */ + protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + def execute() = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => @@ -45,8 +49,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una @transient val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[Row, Row]() - iter.map(r => mutablePair.update(hashExpressions(r), r)) + if (sortBasedShuffleOn) { + iter.map(r => (hashExpressions(r), r.copy())) + } else { + val mutablePair = new MutablePair[Row, Row]() + iter.map(r => mutablePair.update(hashExpressions(r), r)) + } } val part = new HashPartitioner(numPartitions) val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) @@ -58,8 +66,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una implicit val ordering = new RowOrdering(sortingExpressions, child.output) val rdd = child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Row, Null](null, null) - iter.map(row => mutablePair.update(row, null)) + if (sortBasedShuffleOn) { + iter.map(row => (row.copy(), null)) + } else { + val mutablePair = new MutablePair[Row, Null](null, null) + iter.map(row => mutablePair.update(row, null)) + } } val part = new RangePartitioner(numPartitions, rdd, ascending = true) val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) @@ -69,8 +81,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case SinglePartition => val rdd = child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Null, Row]() - iter.map(r => mutablePair.update(null, r)) + if (sortBasedShuffleOn) { + iter.map(r => (null, r.copy())) + } else { + val mutablePair = new MutablePair[Null, Row]() + iter.map(r => mutablePair.update(null, r)) + } } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) From 5f72d7bcf553a0216c4849e1918ed74b96d2224a Mon Sep 17 00:00:00 2001 From: wangfei Date: Wed, 20 Aug 2014 16:00:46 -0700 Subject: [PATCH 215/231] SPARK_LOGFILE and SPARK_ROOT_LOGGER no longer need in spark-daemon.sh Author: wangfei Closes #2057 from scwf/patch-7 and squashes the following commits: 1b7b9a5 [wangfei] SPARK_LOGFILE and SPARK_ROOT_LOGGER no longer need in spark-daemon.sh (cherry picked from commit a1e8b1bc973bc0517681c09e5a5a475c0f395d31) Signed-off-by: Andrew Or --- sbin/spark-daemon.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 323f675b17848..9032f23ea8eff 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -113,8 +113,6 @@ if [ "$SPARK_PID_DIR" = "" ]; then fi # some variables -export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.log -export SPARK_ROOT_LOGGER="INFO,DRFA" log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid From 64e136a640a9ccbde74f7c754b375d175f1991d4 Mon Sep 17 00:00:00 2001 From: Alex Liu Date: Wed, 20 Aug 2014 16:14:06 -0700 Subject: [PATCH 216/231] [SPARK-2846][SQL] Add configureInputJobPropertiesForStorageHandler to initialization of job conf ...al job conf Author: Alex Liu Closes #1927 from alexliu68/SPARK-SQL-2846 and squashes the following commits: e4bdc4c [Alex Liu] SPARK-SQL-2846 add configureInputJobPropertiesForStorageHandler to initial job conf (cherry picked from commit d9e94146a6e65be110a62e3bd0351148912a41d1) Signed-off-by: Michael Armbrust --- .../src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 82c88280d7754..329f80cad471e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} -import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector @@ -249,6 +249,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { FileInputFormat.setInputPaths(jobConf, path) if (tableDesc != null) { + PlanUtils.configureInputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } val bufferSize = System.getProperty("spark.buffer.size", "65536") From 2c1683efeabe461744509548341b8f93d8b22558 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 18 Aug 2014 13:25:30 -0700 Subject: [PATCH 217/231] [SPARK-2169] Don't copy appName / basePath everywhere. Instead of keeping copies in all pages, just reference the values kept in the base SparkUI instance (by making them available via getters). Author: Marcelo Vanzin Closes #1252 from vanzin/SPARK-2169 and squashes the following commits: 4412fc6 [Marcelo Vanzin] Simplify UIUtils.headerSparkPage signature. 4e5d35a [Marcelo Vanzin] [SPARK-2169] Don't copy appName / basePath everywhere. --- .../apache/spark/deploy/master/Master.scala | 2 +- .../scala/org/apache/spark/ui/SparkUI.scala | 9 +++++++++ .../scala/org/apache/spark/ui/UIUtils.scala | 12 +++++------- .../scala/org/apache/spark/ui/WebUI.scala | 3 +++ .../apache/spark/ui/env/EnvironmentPage.scala | 4 +--- .../apache/spark/ui/env/EnvironmentTab.scala | 4 +--- .../apache/spark/ui/exec/ExecutorsPage.scala | 5 +---- .../apache/spark/ui/exec/ExecutorsTab.scala | 6 ++---- .../spark/ui/jobs/JobProgressPage.scala | 4 +--- .../apache/spark/ui/jobs/JobProgressTab.scala | 7 +++---- .../org/apache/spark/ui/jobs/PoolPage.scala | 5 +---- .../org/apache/spark/ui/jobs/PoolTable.scala | 7 +++---- .../org/apache/spark/ui/jobs/StagePage.scala | 8 ++------ .../org/apache/spark/ui/jobs/StageTable.scala | 19 ++++++++++--------- .../org/apache/spark/ui/storage/RDDPage.scala | 8 ++------ .../apache/spark/ui/storage/StoragePage.scala | 6 ++---- .../apache/spark/ui/storage/StorageTab.scala | 4 +--- .../spark/streaming/ui/StreamingPage.scala | 3 +-- .../spark/streaming/ui/StreamingTab.scala | 6 ++---- 19 files changed, 51 insertions(+), 71 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index cfa2c028a807b..5017273e87c07 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -697,7 +697,7 @@ private[spark] class Master( appIdToUI(app.id) = ui webUi.attachSparkUI(ui) // Application UI is successfully rebuilt, so link the Master UI to it - app.desc.appUiUrl = ui.basePath + app.desc.appUiUrl = ui.getBasePath true } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6c788a37dc70b..cccd59d122a92 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -76,6 +76,8 @@ private[spark] class SparkUI( } } + def getAppName = appName + /** Set the app name for this UI. */ def setAppName(name: String) { appName = name @@ -100,6 +102,13 @@ private[spark] class SparkUI( private[spark] def appUIAddress = s"http://$appUIHostPort" } +private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) + extends WebUITab(parent, prefix) { + + def appName: String = parent.getAppName + +} + private[spark] object SparkUI { val DEFAULT_PORT = 4040 val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 715cc2f4df8dd..bee6dad3387e5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -163,17 +163,15 @@ private[spark] object UIUtils extends Logging { /** Returns a spark page with correctly formatted headers */ def headerSparkPage( - content: => Seq[Node], - basePath: String, - appName: String, title: String, - tabs: Seq[WebUITab], - activeTab: WebUITab, + content: => Seq[Node], + activeTab: SparkUITab, refreshInterval: Option[Int] = None): Seq[Node] = { - val header = tabs.map { tab => + val appName = activeTab.appName + val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 5f52f95088007..5d88ca403a674 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -50,6 +50,7 @@ private[spark] abstract class WebUI( protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) + def getBasePath: String = basePath def getTabs: Seq[WebUITab] = tabs.toSeq def getHandlers: Seq[ServletContextHandler] = handlers.toSeq def getSecurityManager: SecurityManager = securityManager @@ -135,6 +136,8 @@ private[spark] abstract class WebUITab(parent: WebUI, val prefix: String) { /** Get a list of header tabs from the parent UI. */ def headerTabs: Seq[WebUITab] = parent.getTabs + + def basePath: String = parent.getBasePath } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index b347eb1b83c1f..f0a1174a71d34 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -24,8 +24,6 @@ import scala.xml.Node import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -45,7 +43,7 @@ private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("")

    Classpath Entries

    {classpathEntriesTable} - UIUtils.headerSparkPage(content, basePath, appName, "Environment", parent.headerTabs, parent) + UIUtils.headerSparkPage("Environment", content, parent) } private def propertyHeader = Seq("Name", "Value") diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index bbbe55ecf44a1..0d158fbe638d3 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -21,9 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.ui._ -private[ui] class EnvironmentTab(parent: SparkUI) extends WebUITab(parent, "environment") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") { val listener = new EnvironmentListener attachPage(new EnvironmentPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b814b0e6b8509..02df4e8fe61af 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -43,8 +43,6 @@ private case class ExecutorSummaryInfo( maxMemory: Long) private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -101,8 +99,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
    ; - UIUtils.headerSparkPage(content, basePath, appName, "Executors (" + execInfo.size + ")", - parent.headerTabs, parent) + UIUtils.headerSparkPage("Executors (" + execInfo.size + ")", content, parent) } /** Render an HTML row representing an executor */ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 5c2d1d1fe75d3..61eb111cd9100 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -23,11 +23,9 @@ import org.apache.spark.ExceptionFailure import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener -import org.apache.spark.ui.{SparkUI, WebUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} -private[ui] class ExecutorsTab(parent: SparkUI) extends WebUITab(parent, "executors") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = new ExecutorsListener(parent.storageStatusListener) attachPage(new ExecutorsPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 0da62892118d4..a82f71ed08475 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -26,8 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc private val listener = parent.listener @@ -94,7 +92,7 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")

    Failed Stages ({failedStages.size})

    ++ failedStagesTable.toNodeSeq - UIUtils.headerSparkPage(content, basePath, appName, "Spark Stages", parent.headerTabs, parent) + UIUtils.headerSparkPage("Spark Stages", content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index 8a01ec80c9dd6..c16542c9db30f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -21,12 +21,10 @@ import javax.servlet.http.HttpServletRequest import org.apache.spark.SparkConf import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, WebUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} /** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stages") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { val live = parent.live val sc = parent.sc val conf = if (live) sc.conf else new SparkConf @@ -53,4 +51,5 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag Thread.sleep(100) } } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 0a2bf31833d2b..7a6c7d1a497ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -26,8 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { - private val appName = parent.appName - private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc private val listener = parent.listener @@ -51,8 +49,7 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {

    Summary

    ++ poolTable.toNodeSeq ++

    {activeStages.size} Active Stages

    ++ activeStagesTable.toNodeSeq - UIUtils.headerSparkPage(content, basePath, appName, "Fair Scheduler Pool: " + poolName, - parent.headerTabs, parent) + UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index f4b68f241966d..64178e1e33d41 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -25,7 +25,6 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { - private val basePath = parent.basePath private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -59,11 +58,11 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { case Some(stages) => stages.size case None => 0 } + val href = "%s/stages/pool?poolname=%s" + .format(UIUtils.prependBaseUri(parent.basePath), p.name) - - {p.name} - + {p.name} {p.minShare} {p.weight} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8bc1ba758cf77..d4eb02722ad12 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -29,8 +29,6 @@ import org.apache.spark.scheduler.AccumulableInfo /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -44,8 +42,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet
    - return UIUtils.headerSparkPage(content, basePath, appName, - "Details for Stage %s".format(stageId), parent.headerTabs, parent) + return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent) } val stageData = stageDataOption.get @@ -227,8 +224,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { maybeAccumulableTable ++

    Tasks

    ++ taskTable - UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), - parent.headerTabs, parent) + UIUtils.headerSparkPage("Details for Stage %d".format(stageId), content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 15998404ed612..16ad0df45aa0d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -32,7 +32,6 @@ private[ui] class StageTableBase( parent: JobProgressTab, killEnabled: Boolean = false) { - private val basePath = parent.basePath private val listener = parent.listener protected def isFairScheduler = parent.isFairScheduler @@ -88,17 +87,19 @@ private[ui] class StageTableBase( private def makeDescription(s: StageInfo): Seq[Node] = { // scalastyle:off val killLink = if (killEnabled) { + val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" + .format(s.stageId) - (kill) + (kill) } // scalastyle:on - val nameLink = - - {s.name} - + val nameLinkUri ="%s/stages/stage?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { @@ -111,7 +112,7 @@ private[ui] class StageTableBase( Text("RDD: ") ++ // scalastyle:off cachedRddInfos.map { i => - {i.name} + {i.name} } // scalastyle:on }} @@ -157,7 +158,7 @@ private[ui] class StageTableBase( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(parent.basePath), stageData.schedulingPool)}> {stageData.schedulingPool} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 84ac53da47552..8a0075ae8daf7 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -27,8 +27,6 @@ import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -36,8 +34,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val storageStatusList = listener.storageStatusList val rddInfo = listener.rddInfoList.find(_.id == rddId).getOrElse { // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage(Seq[Node](), basePath, appName, "RDD Not Found", - parent.headerTabs, parent) + return UIUtils.headerSparkPage("RDD Not Found", Seq[Node](), parent) } // Worker table @@ -96,8 +93,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
    ; - UIUtils.headerSparkPage(content, basePath, appName, "RDD Storage Info for " + rddInfo.name, - parent.headerTabs, parent) + UIUtils.headerSparkPage("RDD Storage Info for " + rddInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 9813d9330ac7f..716591c9ed449 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -27,14 +27,12 @@ import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val rdds = listener.rddInfoList val content = UIUtils.listingTable(rddHeader, rddRow, rdds) - UIUtils.headerSparkPage(content, basePath, appName, "Storage ", parent.headerTabs, parent) + UIUtils.headerSparkPage("Storage", content, parent) } /** Header fields for the RDD table */ @@ -52,7 +50,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { // scalastyle:off - + {rdd.name} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 5f6740d495521..67f72a94f0269 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -25,9 +25,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.storage._ /** Web UI showing storage status of all RDD's in the given SparkContext. */ -private[ui] class StorageTab(parent: SparkUI) extends WebUITab(parent, "storage") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { val listener = new StorageListener(parent.storageStatusListener) attachPage(new StoragePage(this)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 451b23e01c995..1353e487c72cf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -42,8 +42,7 @@ private[ui] class StreamingPage(parent: StreamingTab)

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ generateReceiverStats() ++ generateBatchStatsTable() - UIUtils.headerSparkPage( - content, parent.basePath, parent.appName, "Streaming", parent.headerTabs, parent, Some(5000)) + UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 51448d15c6516..34ac254f337eb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -19,15 +19,13 @@ package org.apache.spark.streaming.ui import org.apache.spark.Logging import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.WebUITab +import org.apache.spark.ui.SparkUITab /** Spark Web UI tab that shows statistics of a streaming job */ private[spark] class StreamingTab(ssc: StreamingContext) - extends WebUITab(ssc.sc.ui, "streaming") with Logging { + extends SparkUITab(ssc.sc.ui, "streaming") with Logging { val parent = ssc.sc.ui - val appName = parent.appName - val basePath = parent.basePath val listener = new StreamingJobProgressListener(ssc) ssc.addStreamingListener(listener) From dc05282bafce8e11de35d7d2f489a8b50a91661d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 20 Aug 2014 15:37:27 -0700 Subject: [PATCH 218/231] [SPARK-2298] Encode stage attempt in SparkListener & UI. Simple way to reproduce this in the UI: ```scala val f = new java.io.File("/tmp/test") f.delete() sc.parallelize(1 to 2, 2).map(x => (x,x )).repartition(3).mapPartitionsWithContext { case (context, iter) => if (context.partitionId == 0) { val f = new java.io.File("/tmp/test") if (!f.exists) { f.mkdir() System.exit(0); } } iter }.count() ``` Author: Reynold Xin Closes #1545 from rxin/stage-attempt and squashes the following commits: 3ee1d2a [Reynold Xin] - Rename attempt to retry in UI. - Properly report stage failure in FetchFailed. 40a6bd5 [Reynold Xin] Updated test suites. c414c36 [Reynold Xin] Fixed the hanging in JobCancellationSuite. b3e2eed [Reynold Xin] Oops previous code didn't compile. 0f36075 [Reynold Xin] Mark unknown stage attempt with id -1 and drop that in JobProgressListener. 6c08b07 [Reynold Xin] Addressed code review feedback. 4e5faa2 [Reynold Xin] [SPARK-2298] Encode stage attempt in SparkListener & UI. --- .../apache/spark/scheduler/DAGScheduler.scala | 77 +-- .../spark/scheduler/SparkListener.scala | 11 +- .../org/apache/spark/scheduler/Stage.scala | 8 +- .../apache/spark/scheduler/StageInfo.scala | 11 +- .../spark/scheduler/TaskSchedulerImpl.scala | 8 +- .../org/apache/spark/scheduler/TaskSet.scala | 4 - .../apache/spark/ui/jobs/ExecutorTable.scala | 6 +- .../spark/ui/jobs/JobProgressListener.scala | 40 +- .../org/apache/spark/ui/jobs/StagePage.scala | 11 +- .../org/apache/spark/ui/jobs/StageTable.scala | 14 +- .../org/apache/spark/util/JsonProtocol.scala | 12 +- .../storage/StorageStatusListenerSuite.scala | 17 +- .../ui/jobs/JobProgressListenerSuite.scala | 68 +-- .../spark/ui/storage/StorageTabSuite.scala | 16 +- .../apache/spark/util/JsonProtocolSuite.scala | 476 ++++++++++++++---- 15 files changed, 555 insertions(+), 224 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b86cfbfa48fbe..34131984570e4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -164,7 +164,7 @@ class DAGScheduler( */ def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics) + taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) implicit val timeout = Timeout(600 seconds) @@ -677,7 +677,10 @@ class DAGScheduler( } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { - listenerBus.post(SparkListenerTaskStart(task.stageId, taskInfo)) + // Note that there is a chance that this task is launched after the stage is cancelled. + // In that case, we wouldn't have the stage anymore in stageIdToStage. + val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) submitWaitingStages() } @@ -695,8 +698,8 @@ class DAGScheduler( // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" runningStages.foreach { stage => - stage.info.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.info)) + stage.latestInfo.stageFailed(stageFailedMessage) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) } listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } @@ -781,7 +784,16 @@ class DAGScheduler( logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() - var tasks = ArrayBuffer[Task[_]]() + + // First figure out the indexes of partition ids to compute. + val partitionsToCompute: Seq[Int] = { + if (stage.isShuffleMap) { + (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil) + } else { + val job = stage.resultOfJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) + } + } val properties = if (jobIdToActiveJob.contains(jobId)) { jobIdToActiveJob(stage.jobId).properties @@ -795,7 +807,8 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast @@ -826,20 +839,19 @@ class DAGScheduler( return } - if (stage.isShuffleMap) { - for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { - val locs = getPreferredLocs(stage.rdd, p) - val part = stage.rdd.partitions(p) - tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs) + val tasks: Seq[Task[_]] = if (stage.isShuffleMap) { + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) } } else { - // This is a final stage; figure out its job's missing partitions val job = stage.resultOfJob.get - for (id <- 0 until job.numPartitions if !job.finished(id)) { + partitionsToCompute.map { id => val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = getPreferredLocs(stage.rdd, p) - tasks += new ResultTask(stage.id, taskBinary, part, locs, id) + new ResultTask(stage.id, taskBinary, part, locs, id) } } @@ -869,11 +881,11 @@ class DAGScheduler( logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stage.info.submissionTime = Some(clock.getTime()) + stage.latestInfo.submissionTime = Some(clock.getTime()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should post // SparkListenerStageCompleted here in case there are no tasks to run. - listenerBus.post(SparkListenerStageCompleted(stage.info)) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) runningStages -= stage @@ -892,8 +904,9 @@ class DAGScheduler( // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, + event.taskInfo, event.taskMetrics)) } if (!stageIdToStage.contains(task.stageId)) { @@ -902,14 +915,19 @@ class DAGScheduler( } val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage) = { - val serviceTime = stage.info.submissionTime match { + def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { + val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) case _ => "Unknown" } - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.info.completionTime = Some(clock.getTime()) - listenerBus.post(SparkListenerStageCompleted(stage.info)) + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTime()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } event.reason match { @@ -924,7 +942,7 @@ class DAGScheduler( val name = acc.name.get val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) val stringValue = Accumulators.stringifyValue(acc.value) - stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) event.taskInfo.accumulables += AccumulableInfo(id, name, Some(stringPartialValue), stringValue) } @@ -935,8 +953,8 @@ class DAGScheduler( logError(s"Failed to update accumulators for $task", e) } } - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, + event.reason, event.taskInfo, event.taskMetrics)) stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => @@ -1029,6 +1047,7 @@ class DAGScheduler( case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => // Mark the stage that the reducer was in as unrunnable val failedStage = stageIdToStage(task.stageId) + markStageAsFinished(failedStage, Some("Fetch failure")) runningStages -= failedStage // TODO: Cancel running tasks in the stage logInfo("Marking " + failedStage + " (" + failedStage.name + @@ -1142,7 +1161,7 @@ class DAGScheduler( } val dependentJobs: Seq[ActiveJob] = activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq - failedStage.info.completionTime = Some(clock.getTime()) + failedStage.latestInfo.completionTime = Some(clock.getTime()) for (job <- dependentJobs) { failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") } @@ -1182,8 +1201,8 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.info.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.info)) + stage.latestInfo.stageFailed(failureReason) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index d01d318633877..86ca8445a1124 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -39,7 +39,8 @@ case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Propert case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent @DeveloperApi -case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent +case class SparkListenerTaskStart(stageId: Int, stageAttemptId: Int, taskInfo: TaskInfo) + extends SparkListenerEvent @DeveloperApi case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent @@ -47,6 +48,7 @@ case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListe @DeveloperApi case class SparkListenerTaskEnd( stageId: Int, + stageAttemptId: Int, taskType: String, reason: TaskEndReason, taskInfo: TaskInfo, @@ -75,10 +77,15 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +/** + * Periodic updates from executors. + * @param execId executor id + * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics) + */ @DeveloperApi case class SparkListenerExecutorMetricsUpdate( execId: String, - taskMetrics: Seq[(Long, Int, TaskMetrics)]) + taskMetrics: Seq[(Long, Int, Int, TaskMetrics)]) extends SparkListenerEvent @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 800905413d145..071568cdfb429 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -43,6 +43,9 @@ import org.apache.spark.util.CallSite * stage, the callSite gives the user code that created the RDD being shuffled. For a result * stage, the callSite gives the user code that executes the associated action (e.g. count()). * + * A single stage can consist of multiple attempts. In that case, the latestInfo field will + * be updated for each attempt. + * */ private[spark] class Stage( val id: Int, @@ -71,8 +74,8 @@ private[spark] class Stage( val name = callSite.shortForm val details = callSite.longForm - /** Pointer to the [StageInfo] object, set by DAGScheduler. */ - var info: StageInfo = StageInfo.fromStage(this) + /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ + var latestInfo: StageInfo = StageInfo.fromStage(this) def isAvailable: Boolean = { if (!isShuffleMap) { @@ -116,6 +119,7 @@ private[spark] class Stage( } } + /** Return a new attempt id, starting with 0. */ def newAttemptId(): Int = { val id = nextAttemptId nextAttemptId += 1 diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 2a407e47a05bd..c6dc3369ba5cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -29,6 +29,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, + val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,9 +57,15 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. */ - def fromStage(stage: Stage): StageInfo = { + def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos - new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos, stage.details) + new StageInfo( + stage.id, + stage.attemptId, + stage.name, + numTasks.getOrElse(stage.numTasks), + rddInfos, + stage.details) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 6c0d1b2752a81..ad051e59af86d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -333,12 +333,12 @@ private[spark] class TaskSchedulerImpl( execId: String, taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId): Boolean = { - val metricsWithStageIds = taskMetrics.flatMap { - case (id, metrics) => { + + val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { + taskMetrics.flatMap { case (id, metrics) => taskIdToTaskSetId.get(id) .flatMap(activeTaskSets.get) - .map(_.stageId) - .map(x => (id, x, metrics)) + .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) } } dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 613fa7850bb25..c3ad325156f53 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,9 +31,5 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." + attempt - def kill(interruptThread: Boolean) { - tasks.foreach(_.kill(interruptThread)) - } - override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 0cc51c873727d..2987dc04494a5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -24,8 +24,8 @@ import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils -/** Page showing executor summary */ -private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { +/** Stage summary grouped by executors. */ +private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -65,7 +65,7 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { executorIdToAddress.put(executorId, address) } - listener.stageIdToData.get(stageId) match { + listener.stageIdToData.get((stageId, stageAttemptId)) match { case Some(stageData: StageUIData) => stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 74cd637d88155..f7f918fd521a9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -43,12 +43,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // How many stages to remember val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) - val activeStages = HashMap[Int, StageInfo]() + // Map from stageId to StageInfo + val activeStages = new HashMap[Int, StageInfo] + + // Map from (stageId, attemptId) to StageUIData + val stageIdToData = new HashMap[(Int, Int), StageUIData] + val completedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() - val stageIdToData = new HashMap[Int, StageUIData] - + // Map from pool name to a hash map (map from stage id to StageInfo). val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() val executorIdToBlockManagerId = HashMap[String, BlockManagerId]() @@ -59,9 +63,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stageInfo - val stageId = stage.stageId - val stageData = stageIdToData.getOrElseUpdate(stageId, { - logWarning("Stage completed for unknown stage " + stageId) + val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { + logWarning("Stage completed for unknown stage " + stage.stageId) new StageUIData }) @@ -69,8 +72,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.accumulables(id) = info } - poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) - activeStages.remove(stageId) + poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap => + hashMap.remove(stage.stageId) + } + activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage trimIfNecessary(completedStages) @@ -84,7 +89,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > retainedStages) { val toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => stageIdToData.remove(s.stageId) } + stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) } stages.trimStart(toRemove) } } @@ -98,21 +103,21 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) - val stageData = stageIdToData.getOrElseUpdate(stage.stageId, new StageUIData) + val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData) stageData.schedulingPool = poolName stageData.description = Option(stageSubmitted.properties).flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]()) + val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) stages(stage.stageId) = stage } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { - val stageData = stageIdToData.getOrElseUpdate(taskStart.stageId, { + val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { logWarning("Task start for unknown stage " + taskStart.stageId) new StageUIData }) @@ -128,8 +133,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo - if (info != null) { - val stageData = stageIdToData.getOrElseUpdate(taskEnd.stageId, { + // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task + // compeletion event is for. Let's just drop it here. This means we might have some speculation + // tasks on the web ui that's never marked as complete. + if (info != null && taskEnd.stageAttemptId != -1) { + val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { logWarning("Task end for unknown stage " + taskEnd.stageId) new StageUIData }) @@ -222,8 +230,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { - for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) { - val stageData = stageIdToData.getOrElseUpdate(sid, { + for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), { logWarning("Metrics update for task in unknown stage " + sid) new StageUIData }) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index d4eb02722ad12..db01be596e073 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -34,7 +34,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val stageId = request.getParameter("id").toInt - val stageDataOption = listener.stageIdToData.get(stageId) + val stageAttemptId = request.getParameter("attempt").toInt + val stageDataOption = listener.stageIdToData.get((stageId, stageAttemptId)) if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) { val content = @@ -42,14 +43,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet - return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent) + return UIUtils.headerSparkPage( + s"Details for Stage $stageId (Attempt $stageAttemptId)", content, parent) } val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) - val accumulables = listener.stageIdToData(stageId).accumulables + val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables val hasInput = stageData.inputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 @@ -211,7 +213,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def quantileRow(data: Seq[Node]): Seq[Node] = {data} Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } - val executorTable = new ExecutorTable(stageId, parent) + + val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = if (accumulables.size > 0) {

    Accumulators

    ++ accumulableTable } else Seq() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 16ad0df45aa0d..2e67310594784 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -97,8 +97,8 @@ private[ui] class StageTableBase( } // scalastyle:on - val nameLinkUri ="%s/stages/stage?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId, s.attemptId) val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -121,7 +121,7 @@ private[ui] class StageTableBase( } val stageDesc = for { - stageData <- listener.stageIdToData.get(s.stageId) + stageData <- listener.stageIdToData.get((s.stageId, s.attemptId)) desc <- stageData.description } yield {
    {desc}
    @@ -131,7 +131,7 @@ private[ui] class StageTableBase( } protected def stageRow(s: StageInfo): Seq[Node] = { - val stageDataOption = listener.stageIdToData.get(s.stageId) + val stageDataOption = listener.stageIdToData.get((s.stageId, s.attemptId)) if (stageDataOption.isEmpty) { return {s.stageId}No data available for this stage } @@ -154,7 +154,11 @@ private[ui] class StageTableBase( val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" - {s.stageId} ++ + {if (s.attemptId > 0) { + {s.stageId} (retry {s.attemptId}) + } else { + {s.stageId} + }} ++ {if (isFairScheduler) { Utils.getFormattedClassName(taskStart)) ~ ("Stage ID" -> taskStart.stageId) ~ + ("Stage Attempt ID" -> taskStart.stageAttemptId) ~ ("Task Info" -> taskInfoToJson(taskInfo)) } @@ -112,6 +113,7 @@ private[spark] object JsonProtocol { val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing ("Event" -> Utils.getFormattedClassName(taskEnd)) ~ ("Stage ID" -> taskEnd.stageId) ~ + ("Stage Attempt ID" -> taskEnd.stageAttemptId) ~ ("Task Type" -> taskEnd.taskType) ~ ("Task End Reason" -> taskEndReason) ~ ("Task Info" -> taskInfoToJson(taskInfo)) ~ @@ -187,6 +189,7 @@ private[spark] object JsonProtocol { val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ + ("Stage Attempt ID" -> stageInfo.attemptId) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ @@ -419,8 +422,9 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") - SparkListenerTaskStart(stageId, taskInfo) + SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } def taskGettingResultFromJson(json: JValue): SparkListenerTaskGettingResult = { @@ -430,11 +434,12 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") val taskMetrics = taskMetricsFromJson(json \ "Task Metrics") - SparkListenerTaskEnd(stageId, taskType, taskEndReason, taskInfo, taskMetrics) + SparkListenerTaskEnd(stageId, stageAttemptId, taskType, taskEndReason, taskInfo, taskMetrics) } def jobStartFromJson(json: JValue): SparkListenerJobStart = { @@ -492,6 +497,7 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] + val attemptId = (json \ "Attempt ID").extractOpt[Int].getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson(_)) @@ -504,7 +510,7 @@ private[spark] object JsonProtocol { case None => Seq[AccumulableInfo]() } - val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos, details) + val stageInfo = new StageInfo(stageId, attemptId, stageName, numTasks, rddInfos, details) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 51fb646a3cb61..7671cb969a26b 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -69,10 +69,10 @@ class StorageStatusListenerSuite extends FunSuite { // Task end with no updated blocks assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics)) assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics)) assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) } @@ -92,13 +92,13 @@ class StorageStatusListenerSuite extends FunSuite { // Task end with new blocks assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) assert(listener.executorIdToStorageStatus("big").numBlocks === 2) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2)) assert(listener.executorIdToStorageStatus("big").numBlocks === 2) assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) @@ -111,13 +111,14 @@ class StorageStatusListenerSuite extends FunSuite { val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3)) taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3)) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) + + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) assert(listener.executorIdToStorageStatus("big").numBlocks === 1) assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2)) assert(listener.executorIdToStorageStatus("big").numBlocks === 1) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) @@ -135,8 +136,8 @@ class StorageStatusListenerSuite extends FunSuite { val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L, 0L)) taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) taskMetrics2.updatedBlocks = Some(Seq(block3)) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics2)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2)) assert(listener.executorIdToStorageStatus("big").numBlocks === 3) // Unpersist RDD diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 147ec0bc52e39..3370dd4156c3f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -34,12 +34,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val listener = new JobProgressListener(conf) def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") SparkListenerStageSubmitted(stageInfo) } def createStageEndEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") SparkListenerStageCompleted(stageInfo) } @@ -70,33 +70,37 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) - .shuffleRead === 1000) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) + assert(listener.stageIdToData.getOrElse((0, 0), fail()) + .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 1000) // finish a task with unknown executor-id, nothing should happen taskInfo = new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) taskInfo.finishTime = 1 task = new ShuffleMapTask(0) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.size === 1) // finish this task, should get updated duration taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) - .shuffleRead === 2000) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) + assert(listener.stageIdToData.getOrElse((0, 0), fail()) + .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 2000) // finish this task, should get updated duration taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) - .shuffleRead === 1000) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) + assert(listener.stageIdToData.getOrElse((0, 0), fail()) + .executorSummary.getOrElse("exe-2", fail()).shuffleRead === 1000) } test("test task success vs failure counting for different task end reasons") { @@ -119,16 +123,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc UnknownReason) var failCount = 0 for (reason <- taskFailedReasons) { - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, reason, taskInfo, metrics)) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, reason, taskInfo, metrics)) failCount += 1 - assert(listener.stageIdToData(task.stageId).numCompleteTasks === 0) - assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) + assert(listener.stageIdToData((task.stageId, 0)).numCompleteTasks === 0) + assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) } // Make sure we count success as success. - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, metrics)) - assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1) - assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, metrics)) + assert(listener.stageIdToData((task.stageId, 1)).numCompleteTasks === 1) + assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) } test("test update metrics") { @@ -163,18 +169,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo } - listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1234L))) - listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1235L))) - listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1236L))) - listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1237L))) + listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1234L))) + listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1235L))) + listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1236L))) + listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( - (1234L, 0, makeTaskMetrics(0)), - (1235L, 0, makeTaskMetrics(100)), - (1236L, 1, makeTaskMetrics(200))))) + (1234L, 0, 0, makeTaskMetrics(0)), + (1235L, 0, 0, makeTaskMetrics(100)), + (1236L, 1, 0, makeTaskMetrics(200))))) - var stage0Data = listener.stageIdToData.get(0).get - var stage1Data = listener.stageIdToData.get(1).get + var stage0Data = listener.stageIdToData.get((0, 0)).get + var stage1Data = listener.stageIdToData.get((1, 0)).get assert(stage0Data.shuffleReadBytes == 102) assert(stage1Data.shuffleReadBytes == 201) assert(stage0Data.shuffleWriteBytes == 106) @@ -195,14 +201,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc .totalBlocksFetched == 202) // task that was included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(0, taskType, Success, makeTaskInfo(1234L, 1), + listener.onTaskEnd(SparkListenerTaskEnd(0, 0, taskType, Success, makeTaskInfo(1234L, 1), makeTaskMetrics(300))) // task that wasn't included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(1, taskType, Success, makeTaskInfo(1237L, 1), + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, taskType, Success, makeTaskInfo(1237L, 1), makeTaskMetrics(400))) - stage0Data = listener.stageIdToData.get(0).get - stage1Data = listener.stageIdToData.get(1).get + stage0Data = listener.stageIdToData.get((0, 0)).get + stage1Data = listener.stageIdToData.get((1, 0)).get assert(stage0Data.shuffleReadBytes == 402) assert(stage1Data.shuffleReadBytes == 602) assert(stage0Data.shuffleWriteBytes == 406) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 6e68dcb3425aa..b860177705d84 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -53,7 +53,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { assert(storageListener.rddInfoList.isEmpty) // 2 RDDs are known, but none are cached - val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0, rddInfo1), "details") + val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0, rddInfo1), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 2) assert(storageListener.rddInfoList.isEmpty) @@ -63,7 +63,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val rddInfo3Cached = rddInfo3 rddInfo2Cached.numCachedPartitions = 1 rddInfo3Cached.numCachedPartitions = 1 - val stageInfo1 = new StageInfo(1, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), "details") + val stageInfo1 = new StageInfo(1, 0, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) assert(storageListener._rddInfoMap.size === 4) assert(storageListener.rddInfoList.size === 2) @@ -71,7 +71,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { // Submitting RDDInfos with duplicate IDs does nothing val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY) rddInfo0Cached.numCachedPartitions = 1 - val stageInfo0Cached = new StageInfo(0, "0", 100, Seq(rddInfo0), "details") + val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached)) assert(storageListener._rddInfoMap.size === 4) assert(storageListener.rddInfoList.size === 2) @@ -87,7 +87,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val rddInfo1Cached = rddInfo1 rddInfo0Cached.numCachedPartitions = 1 rddInfo1Cached.numCachedPartitions = 1 - val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), "details") + val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 2) assert(storageListener.rddInfoList.size === 2) @@ -106,7 +106,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val myRddInfo0 = rddInfo0 val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo(0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") + val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 3) @@ -116,7 +116,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { assert(!storageListener._rddInfoMap(2).isCached) // Task end with no updated blocks. This should not change anything. - bus.postToAll(SparkListenerTaskEnd(0, "obliteration", Success, taskInfo, new TaskMetrics)) + bus.postToAll(SparkListenerTaskEnd(0, 0, "obliteration", Success, taskInfo, new TaskMetrics)) assert(storageListener._rddInfoMap.size === 3) assert(storageListener.rddInfoList.size === 0) @@ -128,7 +128,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { (RDDBlockId(0, 102), BlockStatus(memAndDisk, 400L, 0L, 200L)), (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L, 0L)) )) - bus.postToAll(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo, metrics1)) + bus.postToAll(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo, metrics1)) assert(storageListener._rddInfoMap(0).memSize === 800L) assert(storageListener._rddInfoMap(0).diskSize === 400L) assert(storageListener._rddInfoMap(0).tachyonSize === 200L) @@ -150,7 +150,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L, 0L)), // doesn't actually exist (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L, 0L)) // doesn't actually exist )) - bus.postToAll(SparkListenerTaskEnd(2, "obliteration", Success, taskInfo, metrics2)) + bus.postToAll(SparkListenerTaskEnd(2, 0, "obliteration", Success, taskInfo, metrics2)) assert(storageListener._rddInfoMap(0).memSize === 400L) assert(storageListener._rddInfoMap(0).diskSize === 400L) assert(storageListener._rddInfoMap(0).tachyonSize === 200L) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 97ffb07662482..2fd3b9cfd221a 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -35,13 +35,13 @@ class JsonProtocolSuite extends FunSuite { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) val stageCompleted = SparkListenerStageCompleted(makeStageInfo(101, 201, 301, 401L, 501L)) - val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 1, 444L, false)) + val taskStart = SparkListenerTaskStart(111, 0, makeTaskInfo(222L, 333, 1, 444L, false)) val taskGettingResult = SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 5, 3000L, true)) - val taskEnd = SparkListenerTaskEnd(1, "ShuffleMapTask", Success, + val taskEnd = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = false)) - val taskEndWithHadoopInput = SparkListenerTaskEnd(1, "ShuffleMapTask", Success, + val taskEndWithHadoopInput = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true)) val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties) @@ -397,7 +397,8 @@ class JsonProtocolSuite extends FunSuite { private def assertJsonStringEquals(json1: String, json2: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - assert(formatJsonString(json1) === formatJsonString(json2)) + assert(formatJsonString(json1) === formatJsonString(json2), + s"input ${formatJsonString(json1)} got ${formatJsonString(json2)}") } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { @@ -485,7 +486,7 @@ class JsonProtocolSuite extends FunSuite { private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) } - val stageInfo = new StageInfo(a, "greetings", b, rddInfos, "details") + val stageInfo = new StageInfo(a, 0, "greetings", b, rddInfos, "details") val (acc1, acc2) = (makeAccumulableInfo(1), makeAccumulableInfo(2)) stageInfo.accumulables(acc1.id) = acc1 stageInfo.accumulables(acc2.id) = acc2 @@ -558,84 +559,246 @@ class JsonProtocolSuite extends FunSuite { private val stageSubmittedJsonString = """ - {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":100,"Stage Name": - "greetings","Number of Tasks":200,"RDD Info":[],"Details":"details", - "Accumulables":[{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - {"ID":1,"Name":"Accumulable1","Update":"delta1","Value":"val1"}]},"Properties": - {"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} + |{ + | "Event": "SparkListenerStageSubmitted", + | "Stage Info": { + | "Stage ID": 100, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 200, + | "RDD Info": [], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | "Properties": { + | "France": "Paris", + | "Germany": "Berlin", + | "Russia": "Moscow", + | "Ukraine": "Kiev" + | } + |} """ private val stageCompletedJsonString = """ - {"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":101,"Stage Name": - "greetings","Number of Tasks":201,"RDD Info":[{"RDD ID":101,"Name":"mayor","Storage - Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, - "Replication":1},"Number of Partitions":201,"Number of Cached Partitions":301, - "Memory Size":401,"Tachyon Size":0,"Disk Size":501}],"Details":"details", - "Accumulables":[{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - {"ID":1,"Name":"Accumulable1","Update":"delta1","Value":"val1"}]}} + |{ + | "Event": "SparkListenerStageCompleted", + | "Stage Info": { + | "Stage ID": 101, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 201, + | "RDD Info": [ + | { + | "RDD ID": 101, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 201, + | "Number of Cached Partitions": 301, + | "Memory Size": 401, + | "Tachyon Size": 0, + | "Disk Size": 501 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | } + |} """ private val taskStartJsonString = """ - |{"Event":"SparkListenerTaskStart","Stage ID":111,"Task Info":{"Task ID":222, - |"Index":333,"Attempt":1,"Launch Time":444,"Executor ID":"executor","Host":"your kind sir", - |"Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0, - |"Failed":false,"Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - |"Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - |{"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}]}} + |{ + | "Event": "SparkListenerTaskStart", + | "Stage ID": 111, + | "Stage Attempt ID": 0, + | "Task Info": { + | "Task ID": 222, + | "Index": 333, + | "Attempt": 1, + | "Launch Time": 444, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": false, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" + | } + | ] + | } + |} """.stripMargin private val taskGettingResultJsonString = """ - |{"Event":"SparkListenerTaskGettingResult","Task Info": - | {"Task ID":1000,"Index":2000,"Attempt":5,"Launch Time":3000,"Executor ID":"executor", - | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":true,"Getting Result Time":0, - | "Finish Time":0,"Failed":false, - | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}] + |{ + | "Event": "SparkListenerTaskGettingResult", + | "Task Info": { + | "Task ID": 1000, + | "Index": 2000, + | "Attempt": 5, + | "Launch Time": 3000, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": true, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" + | } + | ] | } |} """.stripMargin private val taskEndJsonString = """ - |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask", - |"Task End Reason":{"Reason":"Success"}, - |"Task Info":{ - | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor", - | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false, - | "Getting Result Time":0,"Finish Time":0,"Failed":false, - | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}] - |}, - |"Task Metrics":{ - | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400, - | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700, - | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0, - | "Shuffle Read Metrics":{ - | "Shuffle Finish Time":900, - | "Remote Blocks Fetched":800, - | "Local Blocks Fetched":700, - | "Fetch Wait Time":900, - | "Remote Bytes Read":1000 + |{ + | "Event": "SparkListenerTaskEnd", + | "Stage ID": 1, + | "Stage Attempt ID": 0, + | "Task Type": "ShuffleMapTask", + | "Task End Reason": { + | "Reason": "Success" | }, - | "Shuffle Write Metrics":{ - | "Shuffle Bytes Written":1200, - | "Shuffle Write Time":1500 + | "Task Info": { + | "Task ID": 123, + | "Index": 234, + | "Attempt": 67, + | "Launch Time": 345, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": false, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" + | } + | ] | }, - | "Updated Blocks":[ - | {"Block ID":"rdd_0_0", - | "Status":{ - | "Storage Level":{ - | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false, - | "Replication":2 - | }, - | "Memory Size":0,"Tachyon Size":0,"Disk Size":0 + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Shuffle Read Metrics": { + | "Shuffle Finish Time": 900, + | "Remote Blocks Fetched": 800, + | "Local Blocks Fetched": 700, + | "Fetch Wait Time": 900, + | "Remote Bytes Read": 1000 + | }, + | "Shuffle Write Metrics": { + | "Shuffle Bytes Written": 1200, + | "Shuffle Write Time": 1500 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "Tachyon Size": 0, + | "Disk Size": 0 + | } | } - | } | ] | } |} @@ -643,80 +806,187 @@ class JsonProtocolSuite extends FunSuite { private val taskEndWithHadoopInputJsonString = """ - |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask", - |"Task End Reason":{"Reason":"Success"}, - |"Task Info":{ - | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor", - | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false, - | "Getting Result Time":0,"Finish Time":0,"Failed":false, - | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}] - |}, - |"Task Metrics":{ - | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400, - | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700, - | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0, - | "Shuffle Write Metrics":{"Shuffle Bytes Written":1200,"Shuffle Write Time":1500}, - | "Input Metrics":{"Data Read Method":"Hadoop","Bytes Read":2100}, - | "Updated Blocks":[ - | {"Block ID":"rdd_0_0", - | "Status":{ - | "Storage Level":{ - | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false, - | "Replication":2 - | }, - | "Memory Size":0,"Tachyon Size":0,"Disk Size":0 + |{ + | "Event": "SparkListenerTaskEnd", + | "Stage ID": 1, + | "Stage Attempt ID": 0, + | "Task Type": "ShuffleMapTask", + | "Task End Reason": { + | "Reason": "Success" + | }, + | "Task Info": { + | "Task ID": 123, + | "Index": 234, + | "Attempt": 67, + | "Launch Time": 345, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": false, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" | } - | } - | ]} + | ] + | }, + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Shuffle Write Metrics": { + | "Shuffle Bytes Written": 1200, + | "Shuffle Write Time": 1500 + | }, + | "Input Metrics": { + | "Data Read Method": "Hadoop", + | "Bytes Read": 2100 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "Tachyon Size": 0, + | "Disk Size": 0 + | } + | } + | ] + | } |} """ private val jobStartJsonString = """ - {"Event":"SparkListenerJobStart","Job ID":10,"Stage IDs":[1,2,3,4],"Properties": - {"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} + |{ + | "Event": "SparkListenerJobStart", + | "Job ID": 10, + | "Stage IDs": [ + | 1, + | 2, + | 3, + | 4 + | ], + | "Properties": { + | "France": "Paris", + | "Germany": "Berlin", + | "Russia": "Moscow", + | "Ukraine": "Kiev" + | } + |} """ private val jobEndJsonString = """ - {"Event":"SparkListenerJobEnd","Job ID":20,"Job Result":{"Result":"JobSucceeded"}} + |{ + | "Event": "SparkListenerJobEnd", + | "Job ID": 20, + | "Job Result": { + | "Result": "JobSucceeded" + | } + |} """ private val environmentUpdateJsonString = """ - {"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"GC speed":"9999 objects/s", - "Java home":"Land of coffee"},"Spark Properties":{"Job throughput":"80000 jobs/s, - regardless of job type"},"System Properties":{"Username":"guest","Password":"guest"}, - "Classpath Entries":{"Super library":"/tmp/super_library"}} + |{ + | "Event": "SparkListenerEnvironmentUpdate", + | "JVM Information": { + | "GC speed": "9999 objects/s", + | "Java home": "Land of coffee" + | }, + | "Spark Properties": { + | "Job throughput": "80000 jobs/s, regardless of job type" + | }, + | "System Properties": { + | "Username": "guest", + | "Password": "guest" + | }, + | "Classpath Entries": { + | "Super library": "/tmp/super_library" + | } + |} """ private val blockManagerAddedJsonString = """ - {"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"Stars", - "Host":"In your multitude...","Port":300,"Netty Port":400},"Maximum Memory":500} + |{ + | "Event": "SparkListenerBlockManagerAdded", + | "Block Manager ID": { + | "Executor ID": "Stars", + | "Host": "In your multitude...", + | "Port": 300, + | "Netty Port": 400 + | }, + | "Maximum Memory": 500 + |} """ private val blockManagerRemovedJsonString = """ - {"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"Scarce", - "Host":"to be counted...","Port":100,"Netty Port":200}} + |{ + | "Event": "SparkListenerBlockManagerRemoved", + | "Block Manager ID": { + | "Executor ID": "Scarce", + | "Host": "to be counted...", + | "Port": 100, + | "Netty Port": 200 + | } + |} """ private val unpersistRDDJsonString = """ - {"Event":"SparkListenerUnpersistRDD","RDD ID":12345} + |{ + | "Event": "SparkListenerUnpersistRDD", + | "RDD ID": 12345 + |} """ private val applicationStartJsonString = """ - {"Event":"SparkListenerApplicationStart","App Name":"The winner of all","Timestamp":42, - "User":"Garfield"} + |{ + | "Event": "SparkListenerApplicationStart", + | "App Name": "The winner of all", + | "Timestamp": 42, + | "User": "Garfield" + |} """ private val applicationEndJsonString = """ - {"Event":"SparkListenerApplicationEnd","Timestamp":42} + |{ + | "Event": "SparkListenerApplicationEnd", + | "Timestamp": 42 + |} """ } From 5e57089a8eefd0939089a26e57b96f08e75968f6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 20 Aug 2014 16:56:13 -0700 Subject: [PATCH 219/231] Revert "[maven-release-plugin] prepare for next development iteration" This reverts commit c204a742a9eb9d3fd318e0f059bd00cbfb8b2c14. --- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 4 ++-- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 24 files changed, 25 insertions(+), 25 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 4709b7dbddfea..799f8d9b3c815 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index f29540b239c73..8eec7e5dd23b5 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index debc4dd703d9a..83e60268afbd2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index f35d3d6a788e3..9bde90ed116e1 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 7f1172ec2092d..9d8a36d9d5b82 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index b127136e3f5a0..6563f4d73da01 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 5123d0554639c..e3df55355c8d5 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 9c00bfc8429a4..1f9e52b19b24e 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 1b9ef4af0c2ed..d28741a9524dc 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 60292a2683212..7e7609985d2c0 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 58b995c5e7005..4d0eac8956955 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 02c9676fb086a..44f50aeefa62e 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 656478583fac2..c677f361cf1f2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 74f528f030987..873f65105b1c5 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/pom.xml b/pom.xml index 1479326af0ed9..359f102bf8f7d 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 pom Spark Project Parent POM http://spark.apache.org/ @@ -40,7 +40,7 @@ scm:git:git@github.com:apache/spark.git scm:git:https://git-wip-us.apache.org/repos/asf/spark.git scm:git:git@github.com:apache/spark.git - HEAD + v1.1.0-snapshot1 diff --git a/repl/pom.xml b/repl/pom.xml index 8748ada36f57a..b7458eeb270dd 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index e2356381c07fb..7c9e5b284e0d9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 3efea9ab8b247..d797753f12151 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index c264ff4ec92e5..d75d2e514544d 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 1e689e6d6dcf2..dc5cfbac8212d 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index c0ce0d7c7478d..31c096380a7c1 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index c601fd5fbbee2..ce3629443ed98 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 18f27b827ff1a..274be3a563641 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 2ba3baf0e3b2e..64fb00ac71b60 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.1-SNAPSHOT + 1.1.0 ../pom.xml From 2bcabcd6a732fa235d2b1279830809f394521fab Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 20 Aug 2014 16:56:18 -0700 Subject: [PATCH 220/231] Revert "[maven-release-plugin] prepare release v1.1.0-snapshot1" This reverts commit d428d88418d385d1d04e1b0adcb6b068efe9c7b0. --- assembly/pom.xml | 6 +++--- bagel/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 9 +++++---- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 5 +++-- sql/core/pom.xml | 5 +++-- sql/hive-thriftserver/pom.xml | 5 +++-- sql/hive/pom.xml | 5 +++-- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 24 files changed, 38 insertions(+), 33 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 799f8d9b3c815..9fbb037115db3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml @@ -124,8 +124,8 @@ log4j.properties - - + +
    diff --git a/bagel/pom.xml b/bagel/pom.xml index 8eec7e5dd23b5..bd51b112e26fa 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 83e60268afbd2..6d8be37037729 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 9bde90ed116e1..8c4c128bb484d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 9d8a36d9d5b82..0c68defa5e101 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 6563f4d73da01..c532705f3950c 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index e3df55355c8d5..4e2275ab238f7 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 1f9e52b19b24e..dc48a08c93de2 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index d28741a9524dc..b93ad016f84f0 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7e7609985d2c0..22c1fff23d9a2 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 4d0eac8956955..a54b34235dfb4 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 44f50aeefa62e..a5b162a0482e4 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index c677f361cf1f2..6dd52fc618b1e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 873f65105b1c5..c7a1e2ae75c84 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 359f102bf8f7d..b8df3d025cfbf 100644 --- a/pom.xml +++ b/pom.xml @@ -16,7 +16,8 @@ ~ limitations under the License. --> - + 4.0.0 org.apache @@ -25,7 +26,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -40,7 +41,7 @@ scm:git:git@github.com:apache/spark.git scm:git:https://git-wip-us.apache.org/repos/asf/spark.git scm:git:git@github.com:apache/spark.git - v1.1.0-snapshot1 + HEAD @@ -879,7 +880,7 @@ . ${project.build.directory}/SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - + true ${session.executionRootDirectory} diff --git a/repl/pom.xml b/repl/pom.xml index b7458eeb270dd..68f4504450778 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 7c9e5b284e0d9..58d44e7923bee 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -16,12 +16,13 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index d797753f12151..c8016e41256d5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -16,12 +16,13 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index d75d2e514544d..c6f60c18804a4 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -16,12 +16,13 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index dc5cfbac8212d..30ff277e67c88 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -16,12 +16,13 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 31c096380a7c1..1072f74aea0d9 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index ce3629443ed98..97abb6b2b63e0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 274be3a563641..3faaf053634d6 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 64fb00ac71b60..b6c8456d06684 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0 + 1.1.0-SNAPSHOT ../pom.xml From f8bcb12c1820402824a8d65dcbb60189e08679c6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 20 Aug 2014 17:07:39 -0700 Subject: [PATCH 221/231] [SPARK-3140] Clarify confusing PySpark exception message We read the py4j port from the stdout of the `bin/spark-submit` subprocess. If there is interference in stdout (e.g. a random echo in `spark-submit`), we throw an exception with a warning message. We do not, however, distinguish between this case from the case where no stdout is produced at all. I wasted a non-trivial amount of time being baffled by this exception in search of places where I print random whitespace (in vain, of course). A clearer exception message that distinguishes between these cases will prevent similar headaches that I have gone through. Author: Andrew Or Closes #2067 from andrewor14/python-exception and squashes the following commits: 742f823 [Andrew Or] Further clarify warning messages e96a7a0 [Andrew Or] Distinguish between unexpected output and no output at all (cherry picked from commit ba3c730e35bcdb662396955c3cc6f7de628034c8) Signed-off-by: Andrew Or --- python/pyspark/java_gateway.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c7f7c1fe591b0..6f4f62f23bc4d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -54,12 +54,19 @@ def preexec_func(): gateway_port = proc.stdout.readline() gateway_port = int(gateway_port) except ValueError: + # Grab the remaining lines of stdout (stdout, _) = proc.communicate() exit_code = proc.poll() error_msg = "Launching GatewayServer failed" - error_msg += " with exit code %d! " % exit_code if exit_code else "! " - error_msg += "(Warning: unexpected output detected.)\n\n" - error_msg += gateway_port + stdout + error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n" + error_msg += "Warning: Expected GatewayServer to output a port, but found " + if gateway_port == "" and stdout == "": + error_msg += "no output.\n" + else: + error_msg += "the following:\n\n" + error_msg += "--------------------------------------------------------------\n" + error_msg += gateway_port + stdout + error_msg += "--------------------------------------------------------------\n" raise Exception(error_msg) # Create a thread to echo output from the GatewayServer, which is required From 1af68caf68d6d34f588723184dc2f75d7578b1d9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 20 Aug 2014 17:41:36 -0700 Subject: [PATCH 222/231] [SPARK-3143][MLLIB] add tf-idf user guide Moved TF-IDF before Word2Vec because the former is more basic. I also added a link for Word2Vec. atalwalkar Author: Xiangrui Meng Closes #2061 from mengxr/tfidf-doc and squashes the following commits: ca04c70 [Xiangrui Meng] address comments a5ea4b4 [Xiangrui Meng] add tf-idf user guide (cherry picked from commit e1571874f26c1df2dfd5ac2959612372716cd2d8) Signed-off-by: Xiangrui Meng --- docs/mllib-feature-extraction.md | 83 ++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 4b3cb715c58c7..2031b96235ee9 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -7,9 +7,88 @@ displayTitle: MLlib - Feature Extraction * Table of contents {:toc} + +## TF-IDF + +[Term frequency-inverse document frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a feature +vectorization method widely used in text mining to reflect the importance of a term to a document in the corpus. +Denote a term by `$t$`, a document by `$d$`, and the corpus by `$D$`. +Term frequency `$TF(t, d)$` is the number of times that term `$t$` appears in document `$d$`, +while document frequency `$DF(t, D)$` is the number of documents that contains term `$t$`. +If we only use term frequency to measure the importance, it is very easy to over-emphasize terms that +appear very often but carry little information about the document, e.g., "a", "the", and "of". +If a term appears very often across the corpus, it means it doesn't carry special information about +a particular document. +Inverse document frequency is a numerical measure of how much information a term provides: +`\[ +IDF(t, D) = \log \frac{|D| + 1}{DF(t, D) + 1}, +\]` +where `$|D|$` is the total number of documents in the corpus. +Since logarithm is used, if a term appears in all documents, its IDF value becomes 0. +Note that a smoothing term is applied to avoid dividing by zero for terms outside the corpus. +The TF-IDF measure is simply the product of TF and IDF: +`\[ +TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). +\]` +There are several variants on the definition of term frequency and document frequency. +In MLlib, we separate TF and IDF to make them flexible. + +Our implementation of term frequency utilizes the +[hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). +A raw feature is mapped into an index (term) by applying a hash function. +Then term frequencies are calculated based on the mapped indices. +This approach avoids the need to compute a global term-to-index map, +which can be expensive for a large corpus, but it suffers from potential hash collisions, +where different raw features may become the same term after hashing. +To reduce the chance of collision, we can increase the target feature dimension, i.e., +the number of buckets of the hash table. +The default feature dimension is `$2^{20} = 1,048,576$`. + +**Note:** MLlib doesn't provide tools for text segmentation. +We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and +[scalanlp/chalk](https://github.com/scalanlp/chalk). + +
    +
    + +TF and IDF are implemented in [HashingTF](api/scala/index.html#org.apache.spark.mllib.feature.HashingTF) +and [IDF](api/scala/index.html#org.apache.spark.mllib.feature.IDF). +`HashingTF` takes an `RDD[Iterable[_]]` as the input. +Each record could be an iterable of strings or other types. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.mllib.feature.HashingTF +import org.apache.spark.mllib.linalg.Vector + +val sc: SparkContext = ... + +// Load documents (one per line). +val documents: RDD[Seq[String]] = sc.textFile("...").map(_.split(" ").toSeq) + +val hashingTF = new HashingTF() +val tf: RDD[Vector] = hasingTF.transform(documents) +{% endhighlight %} + +While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes: +first to compute the IDF vector and second to scale the term frequencies by IDF. + +{% highlight scala %} +import org.apache.spark.mllib.feature.IDF + +// ... continue from the previous example +tf.cache() +val idf = new IDF().fit(tf) +val tfidf: RDD[Vector] = idf.transform(tf) +{% endhighlight %} +
    +
    + ## Word2Vec -Word2Vec computes distributed vector representation of words. The main advantage of the distributed +[Word2Vec](https://code.google.com/p/word2vec/) computes distributed vector representation of words. +The main advantage of the distributed representations is that similar words are close in the vector space, which makes generalization to novel patterns easier and model estimation more robust. Distributed vector representation is showed to be useful in many natural language processing applications such as named entity @@ -69,5 +148,3 @@ for((synonym, cosineSimilarity) <- synonyms) { {% endhighlight %} - -## TFIDF \ No newline at end of file From eba399b3c6768f5106cbc17752630fa81d9cdce4 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 20 Aug 2014 17:47:39 -0700 Subject: [PATCH 223/231] [SPARK-2843][MLLIB] add a section about regularization parameter in ALS atalwalkar srowen Author: Xiangrui Meng Closes #2064 from mengxr/als-doc and squashes the following commits: b2e20ab [Xiangrui Meng] introduced -> discussed 98abdd7 [Xiangrui Meng] add reference 339bd08 [Xiangrui Meng] add a section about regularization parameter in ALS (cherry picked from commit e0f946265b9ea5bc48849cf7794c2c03d5e29fba) Signed-off-by: Xiangrui Meng --- docs/mllib-collaborative-filtering.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index ab10b2f01f87b..d5c539db791be 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -43,6 +43,17 @@ level of confidence in observed user preferences, rather than explicit ratings g model then tries to find latent factors that can be used to predict the expected preference of a user for an item. +### Scaling of the regularization parameter + +Since v1.1, we scale the regularization parameter `lambda` in solving each least squares problem by +the number of ratings the user generated in updating user factors, +or the number of ratings the product received in updating product factors. +This approach is named "ALS-WR" and discussed in the paper +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +It makes `lambda` less dependent on the scale of the dataset. +So we can apply the best parameter learned from a sampled subset to the full dataset +and expect similar performance. + ## Examples
    From 3f91e9dc2563f3c5c473c781bd3078cc620ff880 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 6 Aug 2014 16:34:53 -0700 Subject: [PATCH 224/231] [HOTFIX][Streaming] Handle port collisions in flume polling test This is failing my tests in #1777. @tdas Author: Andrew Or Closes #1803 from andrewor14/fix-flaky-streaming-test and squashes the following commits: ea11a03 [Andrew Or] Catch all exceptions caused by BindExceptions 54a0ca0 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-flaky-streaming-test 664095c [Andrew Or] Tone down bind exception message af3ddc9 [Andrew Or] Handle port collisions in flume polling test --- .../flume/FlumePollingStreamSuite.scala | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 2e4ac7cfbf263..e3a5bdcd24868 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ +import org.apache.spark.util.Utils class FlumePollingStreamSuite extends TestSuiteBase { @@ -45,8 +46,37 @@ class FlumePollingStreamSuite extends TestSuiteBase { val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch val channelCapacity = 5000 + val maxAttempts = 5 test("flume polling test") { + testMultipleTimes(testFlumePolling) + } + + test("flume polling test multiple hosts") { + testMultipleTimes(testFlumePollingMultipleHost) + } + + /** + * Run the given test until no more java.net.BindException's are thrown. + * Do this only up to a certain attempt limit. + */ + private def testMultipleTimes(test: () => Unit): Unit = { + var testPassed = false + var attempt = 0 + while (!testPassed && attempt < maxAttempts) { + try { + test() + testPassed = true + } catch { + case e: Exception if Utils.isBindCollision(e) => + logWarning("Exception when running flume polling test: " + e) + attempt += 1 + } + } + assert(testPassed, s"Test failed after $attempt attempts!") + } + + private def testFlumePolling(): Unit = { val testPort = getTestPort // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) @@ -80,7 +110,7 @@ class FlumePollingStreamSuite extends TestSuiteBase { channel.stop() } - test("flume polling test multiple hosts") { + private def testFlumePollingMultipleHost(): Unit = { val testPort = getTestPort // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) From 44856654c81ceb92ef6380691027744d4bf76589 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Sun, 17 Aug 2014 19:50:31 -0700 Subject: [PATCH 225/231] [HOTFIX][STREAMING] Allow the JVM/Netty to decide which port to bind to in Flume Polling Tests. Author: Hari Shreedharan Closes #1820 from harishreedharan/use-free-ports and squashes the following commits: b939067 [Hari Shreedharan] Remove unused import. 67856a8 [Hari Shreedharan] Remove findFreePort. 0ea51d1 [Hari Shreedharan] Make some changes to getPort to use map on the serverOpt. 1fb0283 [Hari Shreedharan] Merge branch 'master' of https://github.com/apache/spark into use-free-ports b351651 [Hari Shreedharan] Allow Netty to choose port, and query it to decide the port to bind to. Leaving findFreePort as is, if other tests want to use it at some point. e6c9620 [Hari Shreedharan] Making sure the second sink uses the correct port. 11c340d [Hari Shreedharan] Add info about race condition to scaladoc. e89d135 [Hari Shreedharan] Adding Scaladoc. 6013bb0 [Hari Shreedharan] [STREAMING] Find free ports to use before attempting to create Flume Sink in Flume Polling Suite --- .../streaming/flume/sink/SparkSink.scala | 8 +++ .../flume/FlumePollingStreamSuite.scala | 55 +++++++++---------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 1a61b36910a95..98ae7d783aec8 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -130,6 +130,14 @@ class SparkSink extends AbstractSink with Logging with Configurable { blockingLatch.await() Status.BACKOFF } + + private[flume] def getPort(): Int = { + serverOpt + .map(_.getPort) + .getOrElse( + throw new RuntimeException("Server was not started!") + ) + } } /** diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index e3a5bdcd24868..32a19787a28e1 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -22,6 +22,8 @@ import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} import java.util.Random +import org.apache.spark.TestUtils + import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -39,9 +41,6 @@ import org.apache.spark.util.Utils class FlumePollingStreamSuite extends TestSuiteBase { - val random = new Random() - /** Return a port in the ephemeral range. */ - def getTestPort = random.nextInt(16382) + 49152 val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch @@ -77,17 +76,6 @@ class FlumePollingStreamSuite extends TestSuiteBase { } private def testFlumePolling(): Unit = { - val testPort = getTestPort - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", testPort)), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - // Start the channel and sink. val context = new Context() context.put("capacity", channelCapacity.toString) @@ -98,10 +86,19 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), + StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() ssc.start() writeAndVerify(Seq(channel), ssc, outputBuffer) @@ -111,18 +108,6 @@ class FlumePollingStreamSuite extends TestSuiteBase { } private def testFlumePollingMultipleHost(): Unit = { - val testPort = getTestPort - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - // Start the channel and sink. val context = new Context() context.put("capacity", channelCapacity.toString) @@ -136,17 +121,29 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() val sink2 = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort + 1)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink2, context) sink2.setChannel(channel2) sink2.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, + eventsPerBatch, 5) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + ssc.start() writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) assertChannelIsEmpty(channel) From 1d5e84a99076d3e0168dd2f4626c7911e7ba49e7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 20 Aug 2014 22:24:22 -0700 Subject: [PATCH 226/231] HOTFIX:Temporarily removing flume sink test in 1.1 branch --- .../streaming/flume/sink/SparkSinkSuite.scala | 204 ------------------ 1 file changed, 204 deletions(-) delete mode 100644 external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala deleted file mode 100644 index 44b27edf85ce8..0000000000000 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.net.InetSocketAddress -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} - -import scala.collection.JavaConversions._ -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success} - -import com.google.common.util.concurrent.ThreadFactoryBuilder -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.event.EventBuilder -import org.apache.spark.streaming.TestSuiteBase -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory - -class SparkSinkSuite extends TestSuiteBase { - val eventsPerBatch = 1000 - val channelCapacity = 5000 - - test("Success") { - val (channel, sink) = initializeChannelAndSink() - channel.start() - sink.start() - - putEvents(channel, eventsPerBatch) - - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - client.ack(events.getSequenceNumber) - assert(events.getEvents.size() === 1000) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Nack") { - val (channel, sink) = initializeChannelAndSink() - channel.start() - sink.start() - putEvents(channel, eventsPerBatch) - - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - assert(events.getEvents.size() === 1000) - client.nack(events.getSequenceNumber) - assert(availableChannelSlots(channel) === 4000) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Timeout") { - val (channel, sink) = initializeChannelAndSink(Map(SparkSinkConfig - .CONF_TRANSACTION_TIMEOUT -> 1.toString)) - channel.start() - sink.start() - putEvents(channel, eventsPerBatch) - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - assert(events.getEvents.size() === 1000) - Thread.sleep(1000) - assert(availableChannelSlots(channel) === 4000) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Multiple consumers") { - testMultipleConsumers(failSome = false) - } - - test("Multiple consumers with some failures") { - testMultipleConsumers(failSome = true) - } - - def testMultipleConsumers(failSome: Boolean): Unit = { - implicit val executorContext = ExecutionContext - .fromExecutorService(Executors.newFixedThreadPool(5)) - val (channel, sink) = initializeChannelAndSink() - channel.start() - sink.start() - (1 to 5).foreach(_ => putEvents(channel, eventsPerBatch)) - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - val transceiversAndClients = getTransceiverAndClient(address, 5) - val batchCounter = new CountDownLatch(5) - val counter = new AtomicInteger(0) - transceiversAndClients.foreach(x => { - Future { - val client = x._2 - val events = client.getEventBatch(1000) - if (!failSome || counter.getAndIncrement() % 2 == 0) { - client.ack(events.getSequenceNumber) - } else { - client.nack(events.getSequenceNumber) - throw new RuntimeException("Sending NACK for failure!") - } - events - }.onComplete { - case Success(events) => - assert(events.getEvents.size() === 1000) - batchCounter.countDown() - case Failure(t) => - // Don't re-throw the exception, causes a nasty unnecessary stack trace on stdout - batchCounter.countDown() - } - }) - batchCounter.await() - TimeUnit.SECONDS.sleep(1) // Allow the sink to commit the transactions. - executorContext.shutdown() - if(failSome) { - assert(availableChannelSlots(channel) === 3000) - } else { - assertChannelIsEmpty(channel) - } - sink.stop() - channel.stop() - transceiversAndClients.foreach(x => x._1.close()) - } - - private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty): (MemoryChannel, - SparkSink) = { - val channel = new MemoryChannel() - val channelContext = new Context() - - channelContext.put("capacity", channelCapacity.toString) - channelContext.put("transactionCapacity", 1000.toString) - channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides) - channel.configure(channelContext) - - val sink = new SparkSink() - val sinkContext = new Context() - sinkContext.put(SparkSinkConfig.CONF_HOSTNAME, "0.0.0.0") - sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString) - sink.configure(sinkContext) - sink.setChannel(channel) - (channel, sink) - } - - private def putEvents(ch: MemoryChannel, count: Int): Unit = { - val tx = ch.getTransaction - tx.begin() - (1 to count).foreach(x => ch.put(EventBuilder.withBody(x.toString.getBytes))) - tx.commit() - tx.close() - } - - private def getTransceiverAndClient(address: InetSocketAddress, - count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { - - (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) - lazy val channelFactory = - new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) - val transceiver = new NettyTransceiver(address, channelFactory) - val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) - (transceiver, client) - }) - } - - private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - assert(availableChannelSlots(channel) === channelCapacity) - } - - private def availableChannelSlots(channel: MemoryChannel): Int = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] - } -} From e1535ad3c6f7400f2b7915ea91da9c60510557ba Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 21 Aug 2014 05:54:41 +0000 Subject: [PATCH 227/231] [maven-release-plugin] prepare release v1.1.0-snapshot2 --- assembly/pom.xml | 6 +++--- bagel/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 9 ++++----- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 5 ++--- sql/core/pom.xml | 5 ++--- sql/hive-thriftserver/pom.xml | 5 ++--- sql/hive/pom.xml | 5 ++--- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 24 files changed, 33 insertions(+), 38 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 9fbb037115db3..799f8d9b3c815 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml @@ -124,8 +124,8 @@ log4j.properties - - + + diff --git a/bagel/pom.xml b/bagel/pom.xml index bd51b112e26fa..8eec7e5dd23b5 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 6d8be37037729..83e60268afbd2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 8c4c128bb484d..9bde90ed116e1 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0c68defa5e101..9d8a36d9d5b82 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index c532705f3950c..6563f4d73da01 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4e2275ab238f7..e3df55355c8d5 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index dc48a08c93de2..1f9e52b19b24e 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index b93ad016f84f0..d28741a9524dc 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 22c1fff23d9a2..7e7609985d2c0 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index a54b34235dfb4..4d0eac8956955 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index a5b162a0482e4..44f50aeefa62e 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 6dd52fc618b1e..c677f361cf1f2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index c7a1e2ae75c84..873f65105b1c5 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/pom.xml b/pom.xml index b8df3d025cfbf..be190d6548e80 100644 --- a/pom.xml +++ b/pom.xml @@ -16,8 +16,7 @@ ~ limitations under the License. --> - + 4.0.0 org.apache @@ -26,7 +25,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 pom Spark Project Parent POM http://spark.apache.org/ @@ -41,7 +40,7 @@ scm:git:git@github.com:apache/spark.git scm:git:https://git-wip-us.apache.org/repos/asf/spark.git scm:git:git@github.com:apache/spark.git - HEAD + v1.1.0-snapshot2 @@ -880,7 +879,7 @@ . ${project.build.directory}/SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - + true ${session.executionRootDirectory} diff --git a/repl/pom.xml b/repl/pom.xml index 68f4504450778..b7458eeb270dd 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 58d44e7923bee..7c9e5b284e0d9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c8016e41256d5..d797753f12151 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index c6f60c18804a4..d75d2e514544d 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 30ff277e67c88..dc5cfbac8212d 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -16,13 +16,12 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 1072f74aea0d9..31c096380a7c1 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 97abb6b2b63e0..ce3629443ed98 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 3faaf053634d6..274be3a563641 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index b6c8456d06684..64fb00ac71b60 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0-SNAPSHOT + 1.1.0 ../pom.xml From 9af3fb7385d1f9f221962f1d2d725ff79bd82033 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 21 Aug 2014 05:54:48 +0000 Subject: [PATCH 228/231] [maven-release-plugin] prepare for next development iteration --- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 4 ++-- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 24 files changed, 25 insertions(+), 25 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 799f8d9b3c815..4709b7dbddfea 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 8eec7e5dd23b5..f29540b239c73 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 83e60268afbd2..debc4dd703d9a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 9bde90ed116e1..f35d3d6a788e3 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 9d8a36d9d5b82..7f1172ec2092d 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 6563f4d73da01..b127136e3f5a0 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index e3df55355c8d5..5123d0554639c 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 1f9e52b19b24e..9c00bfc8429a4 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index d28741a9524dc..1b9ef4af0c2ed 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7e7609985d2c0..60292a2683212 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 4d0eac8956955..58b995c5e7005 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 44f50aeefa62e..02c9676fb086a 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index c677f361cf1f2..656478583fac2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 873f65105b1c5..74f528f030987 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index be190d6548e80..1479326af0ed9 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -40,7 +40,7 @@ scm:git:git@github.com:apache/spark.git scm:git:https://git-wip-us.apache.org/repos/asf/spark.git scm:git:git@github.com:apache/spark.git - v1.1.0-snapshot2 + HEAD diff --git a/repl/pom.xml b/repl/pom.xml index b7458eeb270dd..8748ada36f57a 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 7c9e5b284e0d9..e2356381c07fb 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index d797753f12151..3efea9ab8b247 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index d75d2e514544d..c264ff4ec92e5 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index dc5cfbac8212d..1e689e6d6dcf2 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 31c096380a7c1..c0ce0d7c7478d 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index ce3629443ed98..c601fd5fbbee2 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 274be3a563641..18f27b827ff1a 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 64fb00ac71b60..2ba3baf0e3b2e 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml From da0a701204ae057581ed2d41eba5bb610e36c864 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 20 Aug 2014 12:18:41 -0700 Subject: [PATCH 229/231] BUILD: Bump Hadoop versions in the release build. Also, minor modifications to the MapR profile. --- dev/create-release/create-release.sh | 10 +++---- pom.xml | 39 +++++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 905dec0ced383..eab6313733dfd 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,11 +118,11 @@ make_binary_release() { } make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" & -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2" \ - "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & -make_binary_release "hadoop2-without-hive" \ - "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & +make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" & +make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" & +make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & +make_binary_release "mapr3" "-Pmapr3 -Pyarn -Phive" & +make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" & wait # Copy data diff --git a/pom.xml b/pom.xml index 1479326af0ed9..bc3aa060e9dfc 100644 --- a/pom.xml +++ b/pom.xml @@ -1114,18 +1114,49 @@ - mapr + mapr3 false 1.0.3-mapr-3.0.3 - 2.3.0-mapr-4.0.0-beta - 0.94.17-mapr-1403 - 3.4.5-mapr-1401 + 2.3.0-mapr-4.0.0-FCS + 0.94.17-mapr-1405 + 3.4.5-mapr-1406 + + mapr4 + + false + + + 2.3.0-mapr-4.0.0-FCS + 2.3.0-mapr-4.0.0-FCS + 0.94.17-mapr-1405-4.0.0-FCS + 3.4.5-mapr-1406 + + + + org.apache.curator + curator-recipes + 2.4.0 + + + org.apache.zookeeper + zookeeper + + + + + org.apache.zookeeper + zookeeper + 3.4.5-mapr-1406 + + + + hadoop-provided From 1e5d9cbb499199304aa8820114fa77dc7a3f0224 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 21 Aug 2014 00:17:29 -0700 Subject: [PATCH 230/231] [SPARK-2840] [mllib] DecisionTree doc update (Java, Python examples) Updated DecisionTree documentation, with examples for Java, Python. Added same Java example to code as well. CC: @mengxr @manishamde @atalwalkar Author: Joseph K. Bradley Closes #2063 from jkbradley/dt-docs and squashes the following commits: 2dd2c19 [Joseph K. Bradley] Last updates based on github review. 9dd1b6b [Joseph K. Bradley] Updated decision tree doc. d802369 [Joseph K. Bradley] Updates based on comments: cache data, corrected doc text. b9bee04 [Joseph K. Bradley] Updated DT examples 57eee9f [Joseph K. Bradley] Created JavaDecisionTree example from example in docs, and corrected doc example as needed. d939a92 [Joseph K. Bradley] Updated DecisionTree documentation. Added Java, Python examples. (cherry picked from commit 050f8d01e47b9b67b02ce50d83fb7b4e528b7204) Signed-off-by: Xiangrui Meng --- docs/mllib-decision-tree.md | 352 ++++++++++++++---- .../examples/mllib/JavaDecisionTree.java | 116 ++++++ 2 files changed, 399 insertions(+), 69 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index c01a92a9a1b26..1166d9cd150c4 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -7,20 +7,26 @@ displayTitle: MLlib - Decision Tree * Table of contents {:toc} -Decision trees and their ensembles are popular methods for the machine learning tasks of +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of classification and regression. Decision trees are widely used since they are easy to interpret, -handle categorical variables, extend to the multiclass classification setting, do not require +handle categorical features, extend to the multiclass classification setting, do not require feature scaling and are able to capture nonlinearities and feature interactions. Tree ensemble -algorithms such as decision forest and boosting are among the top performers for classification and +algorithms such as random forests and boosting are among the top performers for classification and regression tasks. +MLlib supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions of instances. + ## Basic algorithm The decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature -space by choosing a single element from the *best split set* where each element of the set maximizes -the information gain at a tree node. In other words, the split chosen at each tree node is chosen -from the set `$\underset{s}{\operatorname{argmax}} IG(D,s)$` where `$IG(D,s)$` is the information -gain when a split `$s$` is applied to a dataset `$D$`. +space. The tree predicts the same label for each bottommost (leaf) partition. +Each partition is chosen greedily by selecting the *best split* from a set of possible splits, +in order to maximize the information gain at a tree node. In other words, the split chosen at each +tree node is chosen from the set `$\underset{s}{\operatorname{argmax}} IG(D,s)$` where `$IG(D,s)$` +is the information gain when a split `$s$` is applied to a dataset `$D$`. ### Node impurity and information gain @@ -52,9 +58,10 @@ impurity measure for regression (variance). -The *information gain* is the difference in the parent node impurity and the weighted sum of the two -child node impurities. Assuming that a split $s$ partitions the dataset `$D$` of size `$N$` into two -datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, respectively: +The *information gain* is the difference between the parent node impurity and the weighted sum of +the two child node impurities. Assuming that a split $s$ partitions the dataset `$D$` of size `$N$` +into two datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, +respectively, the information gain is: `$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$` @@ -62,14 +69,15 @@ datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, **Continuous features** -For small datasets in single machine implementations, the split candidates for each continuous +For small datasets in single-machine implementations, the split candidates for each continuous feature are typically the unique values for the feature. Some implementations sort the feature values and then use the ordered unique values as split candidates for faster tree calculations. -Finding ordered unique feature values is computationally intensive for large distributed -datasets. One can get an approximate set of split candidates by performing a quantile calculation -over a sampled fraction of the data. The ordered splits create "bins" and the maximum number of such -bins can be specified using the `maxBins` parameters. +Sorting feature values is expensive for large distributed datasets. +This implementation computes an approximate set of split candidates by performing a quantile +calculation over a sampled fraction of the data. +The ordered splits create "bins" and the maximum number of such +bins can be specified using the `maxBins` parameter. Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of @@ -77,109 +85,315 @@ bins if the condition is not satisfied. **Categorical features** -For `$M$` categorical feature values, one could come up with `$2^(M-1)-1$` split candidates. For -binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the -categorical feature values by the proportion of labels falling in one of the two classes (see -Section 9.2.4 in +For a categorical feature with `$M$` possible values (categories), one could come up with +`$2^{M-1}-1$` split candidates. For binary (0/1) classification and regression, +we can reduce the number of split candidates to `$M-1$` by ordering the +categorical feature values by the average label. (See Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for -details). For example, for a binary classification problem with one categorical feature with three -categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical -features are ordered as A followed by C followed B or A, C, B. The two split candidates are A \| C, B -and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification -when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value -is used for ordering. +details.) For example, for a binary classification problem with one categorical feature with three +categories A, B and C whose corresponding proportions of label 1 are 0.2, 0.6 and 0.4, the categorical +features are ordered as A, C, B. The two split candidates are A \| C, B +and A , C \| B where \| denotes the split. + +In multiclass classification, all `$2^{M-1}-1$` possible splits are used whenever possible. +When `$2^{M-1}-1$` is greater than the `maxBins` parameter, we use a (heuristic) method +similar to the method used for binary classification and regression. +The `$M$` categorical feature values are ordered by impurity, +and the resulting `$M-1$` split candidates are considered. ### Stopping rule The recursive tree construction is stopped at a node when one of the two conditions is met: -1. The node depth is equal to the `maxDepth` training parameter +1. The node depth is equal to the `maxDepth` training parameter. 2. No split candidate leads to an information gain at the node. +## Implementation details + ### Max memory requirements -For faster processing, the decision tree algorithm performs simultaneous histogram computations for all nodes at each level of the tree. This could lead to high memory requirements at deeper levels of the tree leading to memory overflow errors. To alleviate this problem, a 'maxMemoryInMB' training parameter is provided which specifies the maximum amount of memory at the workers (twice as much at the master) to be allocated to the histogram computation. The default value is conservatively chosen to be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements for a level-wise computation crosses the `maxMemoryInMB` threshold, the node training tasks at each subsequent level is split into smaller tasks. +For faster processing, the decision tree algorithm performs simultaneous histogram computations for +all nodes at each level of the tree. This could lead to high memory requirements at deeper levels +of the tree, potentially leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB` +training parameter specifies the maximum amount of memory at the workers (twice as much at the +master) to be allocated to the histogram computation. The default value is conservatively chosen to +be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements +for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each +subsequent level are split into smaller tasks. + +Note that, if you have a large amount of memory, increasing `maxMemoryInMB` can lead to faster +training by requiring fewer passes over the data. + +### Binning feature values + +Increasing `maxBins` allows the algorithm to consider more split candidates and make fine-grained +split decisions. However, it also increases computation and communication. + +Note that the `maxBins` parameter must be at least the maximum number of categories `$M$` for +any categorical feature. + +### Scaling -### Practical limitations +Computation scales approximately linearly in the number of training instances, +in the number of features, and in the `maxBins` parameter. +Communication scales approximately linearly in the number of features and in `maxBins`. -1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. -2. Python is not supported in this release. +The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. ## Examples ### Classification -The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then -perform classification using a decision tree using Gini impurity as an impurity measure and a +The example below demonstrates how to load a +[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), +parse it as an RDD of `LabeledPoint` and then +perform classification using a decision tree with Gini impurity as an impurity measure and a maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy.
    +
    {% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Gini - -// Load and parse the data file -val data = sc.textFile("data/mllib/sample_tree_data.csv") -val parsedData = data.map { line => - val parts = line.split(',').map(_.toDouble) - LabeledPoint(parts(0), Vectors.dense(parts.tail)) -} +import org.apache.spark.mllib.util.MLUtils -// Run training algorithm to build the model +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache() + +// Train a DecisionTree model. +// Empty categoricalFeaturesInfo indicates all features are continuous. +val numClasses = 2 +val categoricalFeaturesInfo = Map[Int, Int]() +val impurity = "gini" val maxDepth = 5 -val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth) +val maxBins = 100 + +val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) -// Evaluate model on training examples and compute training error -val labelAndPreds = parsedData.map { point => +// Evaluate model on training instances and compute training error +val labelAndPreds = data.map { point => val prediction = model.predict(point.features) (point.label, prediction) } -val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count +val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count println("Training Error = " + trainErr) +println("Learned classification tree model:\n" + model) +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.HashMap; +import scala.Tuple2; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; + +SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); + +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +String datapath = "data/mllib/sample_libsvm_data.txt"; +JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + +// Set parameters. +// Empty categoricalFeaturesInfo indicates all features are continuous. +Integer numClasses = 2; +HashMap categoricalFeaturesInfo = new HashMap(); +String impurity = "gini"; +Integer maxDepth = 5; +Integer maxBins = 100; + +// Train a DecisionTree model for classification. +final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + +// Evaluate model on training instances and compute training error +JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); +Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); +System.out.println("Training error: " + trainErr); +System.out.println("Learned classification tree model:\n" + model); +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + +# Load and parse the data file into an RDD of LabeledPoint. +# Cache the data since we will use it again to compute training error. +data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache() + +# Train a DecisionTree model. +# Empty categoricalFeaturesInfo indicates all features are continuous. +model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={}, + impurity='gini', maxDepth=5, maxBins=100) + +# Evaluate model on training instances and compute training error +predictions = model.predict(data.map(lambda x: x.features)) +labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions) +trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count()) +print('Training Error = ' + str(trainErr)) +print('Learned classification tree model:') +print(model) {% endhighlight %} + +Note: When making predictions for a dataset, it is more efficient to do batch prediction rather +than separately calling `predict` on each data point. This is because the Python code makes calls +to an underlying `DecisionTree` model in Scala.
    +
    ### Regression -The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then -perform regression using a decision tree using variance as an impurity measure and a maximum tree +The example below demonstrates how to load a +[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), +parse it as an RDD of `LabeledPoint` and then +perform regression using a decision tree with variance as an impurity measure and a maximum tree depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
    +
    {% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Variance - -// Load and parse the data file -val data = sc.textFile("data/mllib/sample_tree_data.csv") -val parsedData = data.map { line => - val parts = line.split(',').map(_.toDouble) - LabeledPoint(parts(0), Vectors.dense(parts.tail)) -} +import org.apache.spark.mllib.util.MLUtils -// Run training algorithm to build the model +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache() + +// Train a DecisionTree model. +// Empty categoricalFeaturesInfo indicates all features are continuous. +val categoricalFeaturesInfo = Map[Int, Int]() +val impurity = "variance" val maxDepth = 5 -val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth) +val maxBins = 100 + +val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) -// Evaluate model on training examples and compute training error -val valuesAndPreds = parsedData.map { point => +// Evaluate model on training instances and compute training error +val labelsAndPredictions = data.map { point => val prediction = model.predict(point.features) (point.label, prediction) } -val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("training Mean Squared Error = " + MSE) +val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() +println("Training Mean Squared Error = " + trainMSE) +println("Learned regression tree model:\n" + model) {% endhighlight %}
    + +
    +{% highlight java %} +import java.util.HashMap; +import scala.Tuple2; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; + +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +String datapath = "data/mllib/sample_libsvm_data.txt"; +JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + +SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); + +// Set parameters. +// Empty categoricalFeaturesInfo indicates all features are continuous. +HashMap categoricalFeaturesInfo = new HashMap(); +String impurity = "variance"; +Integer maxDepth = 5; +Integer maxBins = 100; + +// Train a DecisionTree model. +final DecisionTreeModel model = DecisionTree.trainRegressor(data, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + +// Evaluate model on training instances and compute training error +JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); +Double trainMSE = + predictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); +System.out.println("Training Mean Squared Error: " + trainMSE); +System.out.println("Learned regression tree model:\n" + model); +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + +# Load and parse the data file into an RDD of LabeledPoint. +# Cache the data since we will use it again to compute training error. +data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache() + +# Train a DecisionTree model. +# Empty categoricalFeaturesInfo indicates all features are continuous. +model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={}, + impurity='variance', maxDepth=5, maxBins=100) + +# Evaluate model on training instances and compute training error +predictions = model.predict(data.map(lambda x: x.features)) +labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions) +trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count()) +print('Training Mean Squared Error = ' + str(trainMSE)) +print('Learned regression tree model:') +print(model) +{% endhighlight %} + +Note: When making predictions for a dataset, it is more efficient to do batch prediction rather +than separately calling `predict` on each data point. This is because the Python code makes calls +to an underlying `DecisionTree` model in Scala. +
    +
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java new file mode 100644 index 0000000000000..e4468e8bf1744 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.HashMap; + +import scala.Tuple2; + +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; + +/** + * Classification and regression using decision trees. + */ +public final class JavaDecisionTree { + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + if (args.length == 1) { + datapath = args[0]; + } else if (args.length > 1) { + System.err.println("Usage: JavaDecisionTree "); + System.exit(1); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + HashMap categoricalFeaturesInfo = new HashMap(); + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 100; + + // Train a DecisionTree model for classification. + final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + + // Train a DecisionTree model for regression. + impurity = "variance"; + final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on training instances and compute training error + JavaPairRDD regressorPredictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(regressionModel.predict(p.features()), p.label()); + } + }); + Double trainMSE = + regressorPredictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + regressionModel); + + sc.stop(); + } +} From 1cb61ee2f1f911b02bba84436d87074d5b7a0831 Mon Sep 17 00:00:00 2001 From: chesterxgchen Date: Thu, 21 Aug 2014 13:45:45 -0700 Subject: [PATCH 231/231] SPARK-3175 : Branch-1.1 SBT build failed for Yarn-Alpha The issue is that the yarn/alpha/pom.xml using 1.1.0 instead of 1.1.1-SNAPSHOT version. update the pom.xml to 1.1.1-SNAPSHOT (same as yarn/stable/pom.xml) --- yarn/alpha/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 72d9b1606ad9c..26d2926b11cae 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0 + 1.1.1-SNAPSHOT ../pom.xml