From f8253520045d90c75b143d810edbb746f86cad8c Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Wed, 30 Jul 2014 14:48:41 -0700
Subject: [PATCH 01/21] Wrote Python API and example for DecisionTree.

Also added toString, depth, and numNodes methods to DecisionTreeModel.
---
 examples/src/main/python/mllib/tree.py        |  76 ++++++
 .../mllib/api/python/PythonMLLibAPI.scala     |  79 +++++++
 .../mllib/tree/model/DecisionTreeModel.scala  |  28 +++
 .../apache/spark/mllib/tree/model/Node.scala  |  55 +++++
 python/pyspark/mllib/tree.py                  | 217 ++++++++++++++++++
 5 files changed, 455 insertions(+)
 create mode 100755 examples/src/main/python/mllib/tree.py
 create mode 100644 python/pyspark/mllib/tree.py

diff --git a/examples/src/main/python/mllib/tree.py b/examples/src/main/python/mllib/tree.py
new file mode 100755
index 0000000000000..f7af896d3e57c
--- /dev/null
+++ b/examples/src/main/python/mllib/tree.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Decision tree classification and regression using MLlib.
+"""
+
+import sys
+
+from operator import add
+
+from pyspark import SparkContext
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.tree import DecisionTree
+
+
+# Parse a line of text into an MLlib LabeledPoint object
+def parsePoint(line):
+    values = [float(s) for s in line.split(',')]
+    if values[0] == -1:   # Convert -1 labels to 0 for MLlib
+        values[0] = 0
+    return LabeledPoint(values[0], values[1:])
+
+# Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint].
+def getAccuracy(dtModel, data):
+    seqOp = (lambda acc, x: acc + (x[0] == x[1]))
+    trainCorrect = \
+        dtModel.predict(data).zip(data.map(lambda p: p.label)).aggregate(0, seqOp, add)
+    return trainCorrect / (0.0 + data.count())
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 1:
+        print >> sys.stderr, "Usage: tree"
+        exit(-1)
+    sc = SparkContext(appName="PythonDT")
+
+    # Load data.
+    dataPath = 'data/mllib/sample_tree_data.csv'
+    points = sc.textFile(dataPath).map(parsePoint)
+
+    # Train a classifier.
+    model = DecisionTree.trainClassifier(points, numClasses=2)
+    # Print learned tree.
+    print "Model numNodes: " + str(model.numNodes()) + "\n"
+    print "Model depth: " + str(model.depth()) + "\n"
+    print model
+    # Check accuracy.
+    print "Training accuracy: " + str(getAccuracy(model, points)) + "\n"
+
+    # Switch labels and first feature to create a regression dataset with categorical features.
+    """
+    datasetInfo = DatasetInfo(numClasses=0, numFeatures=numFeatures)
+    dtParams = DecisionTreeRegressor.defaultParams()
+    model = DecisionTreeRegressor.train(points, datasetInfo, dtParams)
+    # Print learned tree.
+    print "Model numNodes: " + model.numNodes() + "\n"
+    print "Model depth: " + model.depth() + "\n"
+    print model
+    # Check error.
+ print "Training accuracy: " + getAccuracy(model, points) + "\n" + """ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 954621ee8b933..b14ea996be1c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} +import scala.collection.JavaConversions._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.mllib.classification._ @@ -26,8 +28,14 @@ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -453,4 +461,75 @@ class PythonMLLibAPI extends Serializable { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + + /** + * Java stub for Python mllib DecisionTree.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + * @param dataBytesJRDD Training data + * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + */ + def trainDecisionTreeModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfoJMap: java.util.Map[Int,Int], + impurityStr: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + + val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + + val algo: Algo = algoStr match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") + } + val impurity: Impurity = impurityStr match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") + } + + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.toMap) + + DecisionTree.train(data, strategy) + } + + /** + * Predict the label of the given data point. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param featuresBytes Serialized feature vector for data point + * @return predicted label + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + featuresBytes: Array[Byte]): Double = { + val features: Vector = deserializeDoubleVector(featuresBytes) + model.predict(features) + } + + /** + * Predict the labels of the given data points. 
+ * This is a Java stub for python DecisionTreeModel.predict() + * + * @param dataJRDD A JavaRDD with serialized feature vectors + * @return JavaRDD of serialized predictions + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { + val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + model.predict(data).map(Utils.serialize(_)) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index bf692ca8c4bd7..156c86b33966b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -50,4 +50,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } + + /** + * Get number of nodes in tree, including leaf nodes. + */ + def numNodes: Int = { + topNode.numNodesRecursive + } + + /** + * Get depth of tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + def depth: Int = { + topNode.depthRecursive + } + + /** + * Print full model. + */ + override def toString: String = algo match { + case Classification => + s"DecisionTreeModel classifier\n" + topNode.toStringRecursive(2) + case Regression => + s"DecisionTreeModel regressor\n" + topNode.toStringRecursive(2) + case _ => throw new IllegalArgumentException( + s"DecisionTreeModel given unknown algo parameter: $algo.") + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 682f213f411a7..b27d546728848 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -91,4 +91,59 @@ class Node ( } } } + + /** + * Get number of nodes in tree from this node, including leaf nodes. + */ + def numNodesRecursive: Int = { + if (isLeaf) { + 1 + } else { + 1 + leftNode.get.numNodesRecursive + rightNode.get.numNodesRecursive + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. + */ + def depthRecursive: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.depthRecursive, rightNode.get.depthRecursive) + } + } + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. 
+ */ + def toStringRecursive(indentFactor: Int = 0): String = { + + def splitToString(split: Split, left: Boolean) : String = { + split.featureType match { + case Continuous => if (left) { + s"(feature ${split.feature} <= ${split.threshold})" + } else { + s"(feature ${split.feature} > ${split.threshold})" + } + case Categorical => if (left) { + s"(feature ${split.feature} in ${split.categories})" + } else { + s"(feature ${split.feature} not in ${split.categories})" + } + } + } + val prefix: String = " " * indentFactor + if (isLeaf) { + prefix + s"Predict: $predict\n" + } else { + prefix + s"If ${splitToString(split.get, left=true)}\n" + + leftNode.get.toStringRecursive(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + + rightNode.get.toStringRecursive(indentFactor + 1) + } + } + } diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py new file mode 100644 index 0000000000000..c04f17e1be23b --- /dev/null +++ b/python/pyspark/mllib/tree.py @@ -0,0 +1,217 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_collections import MapConverter + +from pyspark import SparkContext, RDD +from pyspark.mllib._common import \ + _convert_vector, \ + _dot, _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _deserialize_labeled_point, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ + _linear_predictor_typecheck, _get_unmangled_labeled_point_rdd +from pyspark.mllib.linalg import SparseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.serializers import NoOpSerializer + +class DecisionTreeModel(object): + """ + A decision tree model for classification or regression. + + WARNING: This is an experimental API. It will probably be modified for Spark v1.2. + + # TODO: UPDATE: + >>> from numpy import array, ndarray + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(1.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> datasetInfo = DatasetInfo(2, 1) + >>> params = DecisionTreeClassifier.defaultParams() + >>> dtLearner = DecisionTreeClassifier(params) + >>> model = dtLearner.run(sc.parallelize(data), datasetInfo) + >>> model.predict(array([1.0])) > 0 + True + >>> model.predict(array([0.0])) == 0 + True + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... 
]
+    >>> datasetInfo = DatasetInfo(2, 2)
+    >>> model = dtLearner.run(sc.parallelize(sparse_data), datasetInfo)
+    >>> model.predict(array([0.0, 1.0])) == 1
+    True
+    >>> model.predict(array([0.0, 0.0])) == 0
+    True
+    >>> model.predict(SparseVector(2, {1: 1.0})) == 1
+    True
+    >>> model.predict(SparseVector(2, {1: 0.0})) == 0
+    True
+    """
+
+    def __init__(self, sc, java_model):
+        """
+        :param sc: Spark context
+        :param java_model: Handle to Java model object
+        """
+        self._sc = sc
+        self._java_model = java_model
+
+    def __del__(self):
+        self._sc._gateway.detach(self._java_model)
+
+    def predict(self, x):
+        """
+        :param x: Either one data point (feature vector), or a dataset (RDD of feature vectors)
+        """
+        pythonAPI = self._sc._jvm.PythonMLLibAPI()
+        if type(x) == RDD:
+            # Bulk prediction
+            dataBytes = _get_unmangled_double_vector_rdd(x)
+            jSerializedPreds = pythonAPI.predictDecisionTreeModel(self._java_model, dataBytes._jrdd)
+            serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer())
+            return serializedPreds.map(lambda bytes: _deserialize_labeled_point(bytearray(bytes)))
+        else:
+            if type(x) == LabeledPoint:
+                x_ = _serialize_double_vector(x.features)
+            else:
+                # Assume x is a single data point.
+                x_ = _serialize_double_vector(x)
+            return pythonAPI.predictDecisionTreeModel(self._java_model, x_)
+
+    def numNodes(self):
+        return self._java_model.numNodes()
+
+    def depth(self):
+        return self._java_model.depth()
+
+    def __str__(self):
+        return self._java_model.toString()
+
+
+class DecisionTree(object):
+    """
+    Learning algorithm for a decision tree model for classification or regression.
+
+    WARNING: This is an experimental API. It will probably be modified for Spark v1.2.
+    """
+
+    def run(self, data, datasetInfo):
+        """
+        :param data: RDD of NumPy vectors, one per element, where the first
+               coordinate is the label and the rest is the feature vector.
+               Labels are integers {0,1,...,numClasses-1}.
+        :param datasetInfo: Dataset metadata
+        :return: DecisionTreeModel
+        """
+
+    @staticmethod
+    def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
+                        impurity="gini", maxDepth=4, maxBins=100):
+        """
+        Train a DecisionTreeModel for classification.
+
+        :param data: RDD of NumPy vectors, one per element, where the first
+               coordinate is the label and the rest is the feature vector.
+               Labels are integers {0,1,...,numClasses-1}.
+        :param numClasses: Number of classes for classification.
+        :param categoricalFeaturesInfo: Map from categorical feature index to number of categories.
+               Any feature not in this map is treated as continuous.
+        :param impurity: Supported values: "entropy" or "gini"
+        :param maxDepth: Max depth of tree.
+               E.g., depth 0 means 1 leaf node.
+               Depth 1 means 1 internal node + 2 leaf nodes.
+        :param maxBins: Number of bins used for finding splits at each node.
+        :return: DecisionTreeModel
+        """
+        return DecisionTree.train(data, "classification", numClasses, categoricalFeaturesInfo,
+                                  impurity, maxDepth, maxBins)
+
+    @staticmethod
+    def trainRegressor(data, categoricalFeaturesInfo={},
+                       impurity="variance", maxDepth=4, maxBins=100):
+        """
+        Train a DecisionTreeModel for regression.
+
+        :param data: RDD of NumPy vectors, one per element, where the first
+               coordinate is the label and the rest is the feature vector.
+               Labels are real numbers.
+        :param categoricalFeaturesInfo: Map from categorical feature index to number of categories.
+               Any feature not in this map is treated as continuous.
+        :param impurity: Supported values: "variance"
+        :param maxDepth: Max depth of tree.
+               E.g., depth 0 means 1 leaf node.
+               Depth 1 means 1 internal node + 2 leaf nodes.
+        :param maxBins: Number of bins used for finding splits at each node.
+        :return: DecisionTreeModel
+        """
+        return DecisionTree.train(data, "regression", 0, categoricalFeaturesInfo,
+                                  impurity, maxDepth, maxBins)
+
+
+    @staticmethod
+    def train(data, algo, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins=100):
+        """
+        Train a DecisionTreeModel for classification or regression.
+
+        :param data: RDD of NumPy vectors, one per element, where the first
+               coordinate is the label and the rest is the feature vector.
+               For classification, labels are integers {0,1,...,numClasses-1}.
+               For regression, labels are real numbers.
+        :param algo: "classification" or "regression"
+        :param numClasses: Number of classes for classification. 0 or 1 indicates regression.
+        :param categoricalFeaturesInfo: Map from categorical feature index to number of categories.
+               Any feature not in this map is treated as continuous.
+        :param impurity: For classification: "entropy" or "gini". For regression: "variance".
+        :param maxDepth: Max depth of tree.
+               E.g., depth 0 means 1 leaf node.
+               Depth 1 means 1 internal node + 2 leaf nodes.
+        :param maxBins: Number of bins used for finding splits at each node.
+        :return: DecisionTreeModel
+        """
+        sc = data.context
+        dataBytes = _get_unmangled_labeled_point_rdd(data)
+        categoricalFeaturesInfoJMap = \
+            MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client)
+        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
+            dataBytes._jrdd, algo,
+            numClasses, categoricalFeaturesInfoJMap,
+            impurity, maxDepth, maxBins)
+        return DecisionTreeModel(sc, model)
+
+
+def _test():
+    import doctest
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()

From 5f920a10b6114baa0744f55843969843b1f2babc Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Wed, 30 Jul 2014 15:24:55 -0700
Subject: [PATCH 02/21] Demonstration of bug before submitting fix: Updated
 DecisionTreeSuite so that 3 tests fail.  Will describe bug in next commit.
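
The check these tests rely on can be sketched in isolation. This is a minimal standalone sketch, not code from the patch; it assumes a trained DecisionTreeModel and an in-memory Seq[LabeledPoint], and mirrors the validateClassifier helper added to DecisionTreeSuite below:

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.model.DecisionTreeModel

    // Fraction of examples whose predicted label matches the true label.
    def trainingAccuracy(model: DecisionTreeModel, examples: Seq[LabeledPoint]): Double = {
      val numCorrect = examples.count(lp => model.predict(lp.features) == lp.label)
      numCorrect.toDouble / examples.length
    }

A test then asserts that this accuracy meets a required threshold (1.0 for datasets a stump should fit exactly), which is what exposes the aggregate-indexing bug described in the next commit.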
---
 .../examples/mllib/DecisionTreeRunner.scala   | 92 +++++++++++++----
 .../mllib/tree/model/DecisionTreeModel.scala  | 28 ++++++
 .../apache/spark/mllib/tree/model/Node.scala  | 55 +++++++++++
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 51 +++++++++-
 4 files changed, 205 insertions(+), 21 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 43f13fe24f0d0..30aded532b5ed 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,7 +21,6 @@ import scopt.OptionParser
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree, impurity}
 import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
@@ -36,6 +35,9 @@ import org.apache.spark.rdd.RDD
  *   ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
  * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ *       To include categorical features, modify categoricalFeaturesInfo.
  */
 object DecisionTreeRunner {
@@ -48,11 +50,13 @@ object DecisionTreeRunner {
 
   case class Params(
       input: String = null,
+      dataFormat: String = null,
       algo: Algo = Classification,
       numClassesForClassification: Int = 2,
-      maxDepth: Int = 5,
+      maxDepth: Int = 4,
       impurity: ImpurityType = Gini,
-      maxBins: Int = 100)
+      maxBins: Int = 100,
+      fracTest: Double = 0.2)
 
   def main(args: Array[String]) {
     val defaultParams = Params()
@@ -69,25 +73,32 @@ object DecisionTreeRunner {
       opt[Int]("maxDepth")
         .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
         .action((x, c) => c.copy(maxDepth = x))
-      opt[Int]("numClassesForClassification")
-        .text(s"number of classes for classification, "
-          + s"default: ${defaultParams.numClassesForClassification}")
-        .action((x, c) => c.copy(numClassesForClassification = x))
       opt[Int]("maxBins")
         .text(s"max number of bins, default: ${defaultParams.maxBins}")
         .action((x, c) => c.copy(maxBins = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
       arg[String]("<input>")
        .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
        .required()
        .action((x, c) => c.copy(input = x))
+      arg[String]("<dataFormat>")
+        .text("data format: dense/libsvm")
+        .required()
+        .action((x, c) => c.copy(dataFormat = x))
       checkConfig { params =>
-        if (params.algo == Classification &&
-            (params.impurity == Gini || params.impurity == Entropy)) {
-          success
-        } else if (params.algo == Regression && params.impurity == Variance) {
-          success
+        if (params.fracTest < 0 || params.fracTest > 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
         } else {
-          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+          if (params.algo == Classification &&
+              (params.impurity == Gini || params.impurity == Entropy)) {
+            success
+          } else if (params.algo == Regression && params.impurity == Variance) {
+            success
+          } else {
+            failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+          }
         }
       }
     }
@@ -100,16 +111,57 @@ object DecisionTreeRunner {
   }
 
   def run(params: Params) {
+
     val conf = new SparkConf().setAppName("DecisionTreeRunner")
     val sc = new SparkContext(conf)
 
     // Load training data and cache it.
-    val examples = MLUtils.loadLabeledPoints(sc, params.input).cache()
+    val origExamples = params.dataFormat match {
+      case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
+      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache()
+    }
+    // For classification, re-index classes if needed.
+    val (examples, numClasses) = params.algo match {
+      case Classification => {
+        // classCounts: class --> # examples in class
+        val classCounts = origExamples.map(_.label).countByValue
+        val numClasses = classCounts.size
+        // classIndex: class --> index in 0,...,numClasses-1
+        val classIndex = {
+          if (classCounts.keySet != Set[Double](0.0, 1.0)) {
+            classCounts.keys.toList.sorted.zipWithIndex.toMap
+          } else {
+            Map[Double, Int]()
+          }
+        }
+        val examples = {
+          if (classIndex.isEmpty) {
+            origExamples
+          } else {
+            origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features))
+          }
+        }
+        println(s"numClasses = $numClasses.")
+        println(s"Per-class example fractions, counts:")
+        println(s"Class\tFrac\tCount")
+        classCounts.keys.toList.sorted.foreach(c => {
+          val frac = classCounts(c) / (0.0 + examples.count())
+          println(s"$c\t$frac\t${classCounts(c)}")
+        })
+        (examples, numClasses)
+      }
+      case Regression => {
+        (origExamples, 0)
+      }
+      case _ => {
+        throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+      }
+    }
 
-    val splits = examples.randomSplit(Array(0.8, 0.2))
+    // Split into training, test.
+    val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
     val training = splits(0).cache()
     val test = splits(1).cache()
-
     val numTraining = training.count()
     val numTest = test.count()
@@ -129,17 +181,19 @@ object DecisionTreeRunner {
           impurity = impurityCalculator,
           maxDepth = params.maxDepth,
           maxBins = params.maxBins,
-          numClassesForClassification = params.numClassesForClassification)
+          numClassesForClassification = numClasses)
     val model = DecisionTree.train(training, strategy)
 
+    println(model)
+
     if (params.algo == Classification) {
       val accuracy = accuracyScore(model, test)
-      println(s"Test accuracy = $accuracy.")
+      println(s"Test accuracy = $accuracy")
     }
 
     if (params.algo == Regression) {
      val mse = meanSquaredError(model, test)
-      println(s"Test mean squared error = $mse.")
+      println(s"Test mean squared error = $mse")
     }
 
     sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index bf692ca8c4bd7..156c86b33966b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -50,4 +50,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
   def predict(features: RDD[Vector]): RDD[Double] = {
     features.map(x => predict(x))
   }
+
+  /**
+   * Get number of nodes in tree, including leaf nodes.
+   */
+  def numNodes: Int = {
+    topNode.numNodesRecursive
+  }
+
+  /**
+   * Get depth of tree.
+   * E.g.: Depth 0 means 1 leaf node.  Depth 1 means 1 internal node and 2 leaf nodes.
+   */
+  def depth: Int = {
+    topNode.depthRecursive
+  }
+
+  /**
+   * Print full model.
+ */ + override def toString: String = algo match { + case Classification => + s"DecisionTreeModel classifier\n" + topNode.toStringRecursive(2) + case Regression => + s"DecisionTreeModel regressor\n" + topNode.toStringRecursive(2) + case _ => throw new IllegalArgumentException( + s"DecisionTreeModel given unknown algo parameter: $algo.") + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 682f213f411a7..b27d546728848 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -91,4 +91,59 @@ class Node ( } } } + + /** + * Get number of nodes in tree from this node, including leaf nodes. + */ + def numNodesRecursive: Int = { + if (isLeaf) { + 1 + } else { + 1 + leftNode.get.numNodesRecursive + rightNode.get.numNodesRecursive + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. + */ + def depthRecursive: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.depthRecursive, rightNode.get.depthRecursive) + } + } + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + def toStringRecursive(indentFactor: Int = 0): String = { + + def splitToString(split: Split, left: Boolean) : String = { + split.featureType match { + case Continuous => if (left) { + s"(feature ${split.feature} <= ${split.threshold})" + } else { + s"(feature ${split.feature} > ${split.threshold})" + } + case Categorical => if (left) { + s"(feature ${split.feature} in ${split.categories})" + } else { + s"(feature ${split.feature} not in ${split.categories})" + } + } + } + val prefix: String = " " * indentFactor + if (isLeaf) { + prefix + s"Predict: $predict\n" + } else { + prefix + s"If ${splitToString(split.get, left=true)}\n" + + leftNode.get.toStringRecursive(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + + rightNode.get.toStringRecursive(indentFactor + 1) + } + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5961a618c59d9..761b4e1fac5fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -31,6 +30,18 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map { x => model.predict(x.features) } + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= 
requiredAccuracy) + } + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -602,12 +613,44 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("stump with categorical variables for multiclass classification, with just enough bins") { + val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + val gain = bestSplits(0)._2 + assert(gain.leftImpurity == 0) + assert(gain.rightImpurity == 0) + } + test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -628,6 +671,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) From 2283df878178d3b8c86ecde1d4220076af25b72f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 30 Jul 2014 15:53:14 -0700 Subject: [PATCH 03/21] 2 bug fixes. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Indexing was inconsistent for aggregate calculations for unordered features (in multiclass classification with categorical features, where the features had few enough values such that they could be considered unordered, i.e., isSpaceSufficientForAllCategoricalSplits=true). * updateBinForUnorderedFeature indexed agg as (node, feature, featureValue, binIndex), where ** featureValue was from arr (so it was a feature value) ** binIndex was in [0,…, 2^(maxFeatureValue-1)-1) * The rest of the code indexed agg as (node, feature, binIndex, label). 
* Corrected this bug by changing updateBinForUnorderedFeature to use the second indexing pattern. Unit tests in DecisionTreeSuite * Updated a few tests to train a model and test its training accuracy, which catches the indexing bug from updateBinForUnorderedFeature() discussed above. * Added new test (“stump with categorical variables for multiclass classification, with just enough bins”) to test bin extremes. Bug fix: calculateGainForSplit (for classification): * It used to return dummy prediction values when either the right or left children had 0 weight. These were incorrect for multiclass classification. It has been corrected. Updated impurities to allow for count = 0. This was related to the above bug fix for calculateGainForSplit (for classification). Small updates to documentation and coding style. --- .../spark/mllib/tree/DecisionTree.scala | 103 +++++++++--------- .../spark/mllib/tree/impurity/Entropy.scala | 6 +- .../spark/mllib/tree/impurity/Gini.scala | 6 +- .../spark/mllib/tree/impurity/Impurity.scala | 4 +- .../spark/mllib/tree/impurity/Variance.scala | 6 +- 5 files changed, 66 insertions(+), 59 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad32e3f4560fe..9eddf9f30835c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -598,9 +598,12 @@ object DecisionTree extends Serializable with Logging { // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int) = { - + def updateBinForOrderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + label: Double, + featureIndex: Int) = { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex @@ -612,27 +615,31 @@ object DecisionTree extends Serializable with Logging { agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 } - def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], - label: Double, agg: Array[Double], rightChildShift: Int) = { + def updateBinForUnorderedFeature( + nodeIndex: Int, + featureIndex: Int, + arr: Array[Double], + label: Double, + agg: Array[Double], + rightChildShift: Int) = { // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + val arrIndex = 1 + numFeatures * nodeIndex + featureIndex + val featureValue = arr(arrIndex).toInt // Update the left or right count for one bin. 
- val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val aggShift = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + label.toInt // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { - val labelInt = label.toInt - if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + 1 + val aggIndex = aggShift + binIndex * numClasses + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(aggIndex) += 1 } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + 1 + agg(rightChildShift + aggIndex) += 1 } binIndex += 1 } @@ -815,20 +822,10 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - var classIndex = 0 - val leftCounts: Array[Double] = new Array[Double](numClasses) - val rightCounts: Array[Double] = new Array[Double](numClasses) - var leftTotalCount = 0.0 - var rightTotalCount = 0.0 - while (classIndex < numClasses) { - val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) - val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) - leftCounts(classIndex) = leftClassCount - leftTotalCount += leftClassCount - rightCounts(classIndex) = rightClassCount - rightTotalCount += rightClassCount - classIndex += 1 - } + val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) + val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) + var leftTotalCount = leftCounts.sum + var rightTotalCount = rightCounts.sum val impurity = { if (level > 0) { @@ -845,33 +842,15 @@ object DecisionTree extends Serializable with Logging { } } - if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) - } - if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) - } - - val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) - val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - - val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. 
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } // Sum of count for each label - val leftRightCounts: Array[Double] - = leftCounts.zip(rightCounts) - .map{case (leftCount, rightCount) => leftCount + rightCount} + val leftRightCounts: Array[Double] = leftCounts.zip(rightCounts).map { + case (leftCount, rightCount) => leftCount + rightCount } def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { @@ -885,6 +864,22 @@ object DecisionTree extends Serializable with Logging { val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(leftCounts, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(rightCounts, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a0e2d91762782..9297c20596527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -34,10 +34,13 @@ object Entropy extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 0.0 var classIndex = 0 @@ -58,6 +61,7 @@ object Entropy extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 48144b5e6d1e4..2874bcf496484 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -33,10 +33,13 @@ object Gini extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 1.0 var classIndex = 0 @@ -54,6 +57,7 @@ object Gini extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi 
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 7b2a9320cc21d..92b0c7b4a6fbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -31,7 +31,7 @@ trait Impurity extends Serializable { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -42,7 +42,7 @@ trait Impurity extends Serializable { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 97149a99ead59..698a1a2a8e899 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -31,7 +31,7 @@ object Variance extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = @@ -43,9 +43,13 @@ object Variance extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + if (count == 0) { + return 0 + } val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } From 8ea8750cd5eeefa87d937ca4214a5f548dd2e6a4 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 30 Jul 2014 17:05:49 -0700 Subject: [PATCH 04/21] Bug fix: Off-by-1 when finding thresholds for splits for continuous features. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Exhibited bug in new test in DecisionTreeSuite: “stump with 1 continuous variable for binary classification, to check off-by-1 error” * Description: When finding thresholds for possible splits for continuous features in DecisionTree.findSplitsBins, the thresholds were set according to individual training examples’ feature values. This can cause problems for small datasets, when the number of training examples equals numBins. * Fix: The threshold is set to be the average of 2 consecutive (sorted) examples’ feature values. E.g.: If the old code set the threshold using example i, the new code sets the threshold using examples i and i+1. * Note: In 4 DecisionTreeSuite tests with all labels identical, removed check of threshold since it is somewhat arbitrary. 
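
In isolation, the midpoint rule reads as follows. This is a sketch under simplifying assumptions (a pre-sorted sample of one feature's values, with sortedSamples.length >= numBins), not the actual findSplitsBins code:

    // Place each candidate threshold halfway between two consecutive sorted
    // samples, instead of on a sample value itself.
    def continuousThresholds(sortedSamples: Array[Double], numBins: Int): Array[Double] = {
      val stride: Double = sortedSamples.length.toDouble / numBins
      Array.tabulate(numBins - 1) { index =>
        val sampleIndex = index * stride.toInt
        (sortedSamples(sampleIndex) + sortedSamples(sampleIndex + 1)) / 2.0
      }
    }

E.g., sorted samples 0, 1, 2, 3 with numBins = 4 give thresholds 0.5, 1.5, 2.5, so the 0-labeled example in the new off-by-1 test separates cleanly from the rest instead of a threshold landing exactly on a training value.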
--- .../spark/mllib/tree/DecisionTree.scala | 6 ++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 20 +++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 9eddf9f30835c..d10fe53bf9021 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1289,8 +1289,10 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) for (index <- 0 until numBins - 1) { - val sampleIndex = (index + 1) * stride.toInt - val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + val sampleIndex = index * stride.toInt + // Set threshold halfway in between 2 samples. + val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + val split = new Split(featureIndex, threshold, Continuous, List()) splits(featureIndex)(index) = split } } else { // Categorical feature diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 761b4e1fac5fd..a06dbec92de2b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -471,7 +471,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -494,7 +493,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -518,7 +516,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -542,7 +539,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -613,6 +609,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) + arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = 
Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + test("stump with categorical variables for multiclass classification, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() From 8e227ea826d6b38dc47e9a90ccf6683348c6dab0 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 30 Jul 2014 17:18:55 -0700 Subject: [PATCH 05/21] Changed Strategy so it only requires numClassesForClassification >= 2 for classification --- .../mllib/tree/configuration/Strategy.scala | 4 ++- python/pyspark/mllib/tree.py | 36 ------------------- 2 files changed, 3 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7c027ac2fda6b..3b003fef4bdb2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -52,7 +52,9 @@ class Strategy ( val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), val maxMemoryInMB: Int = 128) extends Serializable { - require(numClassesForClassification >= 2) + if (algo == Classification) { + require(numClassesForClassification >= 2) + } val isMulticlassClassification = numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index c04f17e1be23b..4eb9d8e739ecf 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -35,42 +35,6 @@ class DecisionTreeModel(object): A decision tree model for classification or regression. WARNING: This is an experimental API. It will probably be modified for Spark v1.2. - - # TODO: UPDATE: - >>> from numpy import array, ndarray - >>> from pyspark.mllib.regression import LabeledPoint - >>> from pyspark.mllib.tree import DecisionTree - >>> from pyspark.mllib.linalg import SparseVector - >>> data = [ - ... LabeledPoint(0.0, [0.0]), - ... LabeledPoint(1.0, [1.0]), - ... LabeledPoint(1.0, [2.0]), - ... LabeledPoint(1.0, [3.0]) - ... ] - >>> datasetInfo = DatasetInfo(2, 1) - >>> params = DecisionTreeClassifier.defaultParams() - >>> dtLearner = DecisionTreeClassifier(params) - >>> model = dtLearner.run(sc.parallelize(data), datasetInfo) - >>> model.predict(array([1.0])) > 0 - True - >>> model.predict(array([0.0])) == 0 - True - >>> sparse_data = [ - ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), - ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), - ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), - ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) - ... ] - >>> datasetInfo = DatasetInfo(2, 2) - >>> model = dtLearner.run(sc.parallelize(sparse_data), datasetInfo) - >>> model.predict(array([0.0, 1.0])) == 1 - True - >>> model.predict(array([0.0, 0.0])) == 0 - True - >>> model.predict(SparseVector(2, {1: 1.0})) == 1 - True - >>> model.predict(SparseVector(2, {1: 0.0})) == 0 - True """ def __init__(self, sc, java_model): From da50db749f54a63565440d6c42f78373f1f2a2ac Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Wed, 30 Jul 2014 17:32:10 -0700 Subject: [PATCH 06/21] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now. --- .../spark/mllib/tree/DecisionTreeSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a06dbec92de2b..be4f761997733 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -625,6 +625,24 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) } + test("stump with 2 continuous variables for binary classification") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + assert(model.topNode.split.get.feature === 1) + } + test("stump with categorical variables for multiclass classification, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() From 59750f87c974299720ec556908c7e29b131d3476 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 31 Jul 2014 11:08:46 -0700 Subject: [PATCH 07/21] * Updated Strategy to check numClassesForClassification only if algo=Classification. * Updates based on comments: ** DecisionTreeRunner *** Made dataFormat arg default to libsvm ** Small cleanups ** tree.Node: Made recursive helper methods private, and renamed them. 
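
With the relaxed require() described above, the two configurations can be set up as sketched below. This is a minimal sketch assuming only the Strategy constructor parameters visible in this patch series (including the default numClassesForClassification = 2), not code from the patch:

    import org.apache.spark.mllib.tree.configuration.Algo._
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.{Gini, Variance}

    // Classification still enforces numClassesForClassification >= 2.
    val classifierStrategy = new Strategy(algo = Classification, impurity = Gini,
      maxDepth = 4, numClassesForClassification = 2)

    // Regression now skips that check entirely, so the (ignored) class count
    // no longer has to satisfy it.
    val regressorStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 4)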
---
 .../examples/mllib/DecisionTreeRunner.scala   | 34 +++++++++----------
 .../spark/mllib/tree/DecisionTree.scala       | 10 +++---
 .../mllib/tree/configuration/Strategy.scala   |  4 ++-
 .../mllib/tree/model/DecisionTreeModel.scala  |  8 ++---
 .../apache/spark/mllib/tree/model/Node.scala  | 25 +++++++-------
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 12 +++----
 6 files changed, 49 insertions(+), 44 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 6c338db65dde1..5b34a3a41c7f9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -50,9 +50,8 @@ object DecisionTreeRunner {
 
   case class Params(
       input: String = null,
-      dataFormat: String = null,
+      dataFormat: String = "libsvm",
       algo: Algo = Classification,
-      numClassesForClassification: Int = 2,
       maxDepth: Int = 4,
       impurity: ImpurityType = Gini,
       maxBins: Int = 100,
@@ -79,14 +78,13 @@ object DecisionTreeRunner {
       opt[Double]("fracTest")
         .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
+      opt[String]("<dataFormat>")
+        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+        .action((x, c) => c.copy(dataFormat = x))
       arg[String]("<input>")
         .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
         .required()
         .action((x, c) => c.copy(input = x))
-      arg[String]("<dataFormat>")
-        .text("data format: dense/libsvm")
-        .required()
-        .action((x, c) => c.copy(dataFormat = x))
       checkConfig { params =>
         if (params.fracTest < 0 || params.fracTest > 1) {
           failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
         } else {
@@ -118,36 +116,38 @@ object DecisionTreeRunner {
     // Load training data and cache it.
     val origExamples = params.dataFormat match {
       case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
-      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache()
+      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
     }
     // For classification, re-index classes if needed.
val (examples, numClasses) = params.algo match { case Classification => { // classCounts: class --> # examples in class - val classCounts = origExamples.map(_.label).countByValue + val classCounts = origExamples.map(_.label).countByValue() + val sortedClasses = classCounts.keys.toList.sorted val numClasses = classCounts.size - // classIndex: class --> index in 0,...,numClasses-1 - val classIndex = { - if (classCounts.keySet != Set[Double](0.0, 1.0)) { - classCounts.keys.toList.sorted.zipWithIndex.toMap + // classIndexMap: class --> index in 0,...,numClasses-1 + val classIndexMap = { + if (classCounts.keySet != Set(0.0, 1.0)) { + sortedClasses.zipWithIndex.toMap } else { Map[Double, Int]() } } val examples = { - if (classIndex.isEmpty) { + if (classIndexMap.isEmpty) { origExamples } else { - origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features)) + origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features)) } } + val numExamples = examples.count() println(s"numClasses = $numClasses.") println(s"Per-class example fractions, counts:") println(s"Class\tFrac\tCount") - classCounts.keys.toList.sorted.foreach(c => { - val frac = classCounts(c) / (0.0 + examples.count()) + sortedClasses.foreach { c => { + val frac = classCounts(c) / (0.0 + numExamples) println(s"$c\t$frac\t${classCounts(c)}") - }) + }} (examples, numClasses) } case Regression => { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index d10fe53bf9021..253a4bbe424eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -824,8 +824,8 @@ object DecisionTree extends Serializable with Logging { case Classification => val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) - var leftTotalCount = leftCounts.sum - var rightTotalCount = rightCounts.sum + val leftTotalCount = leftCounts.sum + val rightTotalCount = rightCounts.sum val impurity = { if (level > 0) { @@ -849,8 +849,10 @@ object DecisionTree extends Serializable with Logging { } // Sum of count for each label - val leftRightCounts: Array[Double] = leftCounts.zip(rightCounts).map { - case (leftCount, rightCount) => leftCount + rightCount } + val leftRightCounts: Array[Double] = + leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => + leftCount + rightCount + } def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7c027ac2fda6b..3b003fef4bdb2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -52,7 +52,9 @@ class Strategy ( val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), val maxMemoryInMB: Int = 128) extends Serializable { - require(numClassesForClassification >= 2) + if (algo == Classification) { + require(numClassesForClassification >= 2) + } val isMulticlassClassification = numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 156c86b33966b..d558ab19dfb35 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -55,7 +55,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * Get number of nodes in tree, including leaf nodes. */ def numNodes: Int = { - topNode.numNodesRecursive + 1 + topNode.numDescendants } /** @@ -63,7 +63,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. */ def depth: Int = { - topNode.depthRecursive + topNode.subtreeDepth } /** @@ -71,9 +71,9 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable */ override def toString: String = algo match { case Classification => - s"DecisionTreeModel classifier\n" + topNode.toStringRecursive(2) + s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2) case Regression => - s"DecisionTreeModel regressor\n" + topNode.toStringRecursive(2) + s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2) case _ => throw new IllegalArgumentException( s"DecisionTreeModel given unknown algo parameter: $algo.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index b27d546728848..944f11c2c2e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -93,13 +93,14 @@ class Node ( } /** - * Get number of nodes in tree from this node, including leaf nodes. + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. */ - def numNodesRecursive: Int = { + private[tree] def numDescendants: Int = { if (isLeaf) { - 1 + 0 } else { - 1 + leftNode.get.numNodesRecursive + rightNode.get.numNodesRecursive + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants } } @@ -107,11 +108,11 @@ class Node ( * Get depth of tree from this node. * E.g.: Depth 0 means this is a leaf node. */ - def depthRecursive: Int = { + private[tree] def subtreeDepth: Int = { if (isLeaf) { 0 } else { - 1 + math.max(leftNode.get.depthRecursive, rightNode.get.depthRecursive) + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) } } @@ -119,9 +120,9 @@ class Node ( * Recursive print function. * @param indentFactor The number of spaces to add to each level of indentation. 
*/ - def toStringRecursive(indentFactor: Int = 0): String = { + private[tree] def subtreeToString(indentFactor: Int = 0): String = { - def splitToString(split: Split, left: Boolean) : String = { + def splitToString(split: Split, left: Boolean): String = { split.featureType match { case Continuous => if (left) { s"(feature ${split.feature} <= ${split.threshold})" @@ -129,9 +130,9 @@ class Node ( s"(feature ${split.feature} > ${split.threshold})" } case Categorical => if (left) { - s"(feature ${split.feature} in ${split.categories})" + s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" } else { - s"(feature ${split.feature} not in ${split.categories})" + s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" } } } @@ -140,9 +141,9 @@ class Node ( prefix + s"Predict: $predict\n" } else { prefix + s"If ${splitToString(split.get, left=true)}\n" + - leftNode.get.toStringRecursive(indentFactor + 1) + + leftNode.get.subtreeToString(indentFactor + 1) + prefix + s"Else ${splitToString(split.get, left=false)}\n" + - rightNode.get.toStringRecursive(indentFactor + 1) + rightNode.get.subtreeToString(indentFactor + 1) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index be4f761997733..973e2f03bda29 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -34,7 +34,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { model: DecisionTreeModel, input: Seq[LabeledPoint], requiredAccuracy: Double) { - val predictions = input.map { x => model.predict(x.features) } + val predictions = input.map(x => model.predict(x.features)) val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label } @@ -247,7 +247,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("extract categories from a number for multiclass classification") { val l = DecisionTree.extractMultiClassCategories(13, 10) assert(l.length === 3) - assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } test("split and bin calculations for unordered categorical variables with multiclass " + @@ -424,7 +424,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) assert(stats.predict === 1) - assert(stats.prob == 0.6) + assert(stats.prob === 0.6) assert(stats.impurity > 0.2) } @@ -450,7 +450,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict == 0.6) + assert(stats.predict === 0.6) assert(stats.impurity > 0.2) } @@ -667,8 +667,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.categories.contains(1)) assert(bestSplit.featureType === Categorical) val gain = bestSplits(0)._2 - assert(gain.leftImpurity == 0) - assert(gain.rightImpurity == 0) + assert(gain.leftImpurity === 0) + assert(gain.rightImpurity === 0) } test("stump with continuous variables for multiclass classification") { From 376dca2c848739b1536e6ee8ddbc55043d1eef7a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 31 Jul 2014 11:27:18 -0700 Subject: [PATCH 08/21] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. 
* In code, replaced usages of maxDepth <-- maxDepth + 1 * In params, replaced settings of maxDepth <-- maxDepth - 1 --- .../spark/mllib/tree/DecisionTree.scala | 17 +++++++----- .../mllib/tree/configuration/Strategy.scala | 3 ++- .../spark/mllib/tree/DecisionTreeSuite.scala | 26 +++++++++---------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 253a4bbe424eb..7dde7d0116a12 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -100,7 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var level = 0 var break = false - while (level < maxDepth && !break) { + while (level <= maxDepth && !break) { logDebug("#####################################") logDebug("level = " + level) @@ -152,7 +152,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -173,7 +173,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { + if (level < maxDepth) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { @@ -223,7 +223,8 @@ object DecisionTree extends Serializable with Logging { * training data * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @return a DecisionTreeModel that can be used for prediction */ def train( @@ -245,7 +246,8 @@ object DecisionTree extends Serializable with Logging { * training data * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2.
* @return a DecisionTreeModel that can be used for prediction */ @@ -270,7 +272,8 @@ object DecisionTree extends Serializable with Logging { * training data for DecisionTree * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 3b003fef4bdb2..5c65b537b6867 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,7 +27,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * Stores all the configuration options for tree construction * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value is 2 * leads to binary classification * @param maxBins maximum number of bins used for splitting features diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 973e2f03bda29..10462db700628 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -61,7 +61,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) @@ -141,7 +141,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) @@ -258,7 +258,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) @@ -408,7 +408,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, numClassesForClassification = 2, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -435,7 +435,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Regression, 
Variance, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) @@ -594,7 +594,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) @@ -616,7 +616,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) val model = DecisionTree.train(input, strategy) @@ -633,7 +633,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) val model = DecisionTree.train(input, strategy) @@ -647,7 +647,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) @@ -674,7 +674,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) @@ -698,7 +698,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) @@ -721,7 +721,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = 
DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) From 6eed4822759377b241c8dd0adadf32102e01d472 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 31 Jul 2014 11:39:00 -0700 Subject: [PATCH 09/21] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type. --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 7dde7d0116a12..c15ee4f6ba7da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -606,7 +606,7 @@ object DecisionTree extends Serializable with Logging { agg: Array[Double], nodeIndex: Int, label: Double, - featureIndex: Int) = { + featureIndex: Int): Unit = { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex @@ -624,7 +624,7 @@ object DecisionTree extends Serializable with Logging { arr: Array[Double], label: Double, agg: Array[Double], - rightChildShift: Int) = { + rightChildShift: Int): Unit = { // Find the bin index for this feature. val arrIndex = 1 + numFeatures * nodeIndex + featureIndex val featureValue = arr(arrIndex).toInt @@ -659,7 +659,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 2 * numSplits * numFeatures * numNodes for classification */ - def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -691,7 +691,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 2 * numClasses * numSplits * numFeatures * numNodes for classification */ - def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -736,7 +736,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 3 * numSplits * numFeatures * numNodes for regression */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { From dab0b674b93c7ada8e9d8ac1fc364c0c9438785b Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 31 Jul 2014 13:08:46 -0700 Subject: [PATCH 10/21] Added documentation for DecisionTree internals --- .../examples/mllib/DecisionTreeRunner.scala | 2 +- .../spark/mllib/tree/DecisionTree.scala | 278 ++++++++++++------ .../mllib/tree/model/DecisionTreeModel.scala | 3 +- 3 files changed, 199 insertions(+), 84 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 5b34a3a41c7f9..9736b624e415d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -145,7 +145,7 @@ object DecisionTreeRunner { println(s"Per-class example fractions, counts:") println(s"Class\tFrac\tCount") sortedClasses.foreach { c => { - val frac = classCounts(c) / (0.0 + numExamples) + val frac = classCounts(c) / numExamples.toDouble println(s"$c\t$frac\t${classCounts(c)}") }} (examples, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c15ee4f6ba7da..31574182094cb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * :: Experimental :: - * A class that implements a decision tree algorithm for classification and regression. It - * supports both continuous and categorical features. + * A class which implements a decision tree learning algorithm for classification and regression. + * It supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. @@ -42,8 +42,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Method to train a decision tree model over an RDD - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @return a DecisionTreeModel that can be used for prediction + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return DecisionTreeModel which can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -197,17 +197,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo object DecisionTree extends Serializable with Logging { /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. The parameters for the algorithm are specified using the strategy parameter. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. 
+ * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel which can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).train(input) @@ -219,13 +218,14 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel which can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -242,14 +242,15 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel which can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -268,8 +269,9 @@ object DecisionTree extends Serializable with Logging { * 1 to denote the two classes. The method also supports categorical features inputs where the * number of categories can be specified using the categoricalFeaturesInfo option. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth Maximum depth of the tree. @@ -282,7 +284,7 @@ object DecisionTree extends Serializable with Logging { * an entry (n -> k) implies the feature n is categorical with k * categories 0, 1, 2, ... , k-1. It's important to note that * features are zero-indexed.
- * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel which can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -304,11 +306,10 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -351,11 +352,10 @@ object DecisionTree extends Serializable with Logging { /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -376,7 +376,7 @@ object DecisionTree extends Serializable with Logging { groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* - * The high-level description for the best split optimizations are noted here. + * The high-level descriptions of the best split optimizations are noted here. * * *Level-wise training* * We perform bin calculations for all nodes at the given level to avoid making multiple @@ -399,18 +399,27 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // common calculations for multiple nested methods + // Common calculations for multiple nested methods: + + // numNodes: Number of nodes in this (level of tree, group), + // where nodes at deeper (larger) levels may be divided into groups. val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size logDebug("numFeatures = " + numFeatures) + + // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) + val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlassClassificationWithCategoricalFeatures = strategy.isMulticlassWithCategoricalFeatures logDebug("isMultiClassWithCategoricalFeatures = " + @@ -468,10 +477,13 @@ object DecisionTree extends Serializable with Logging { } /** - * Find bin for one feature. + * Find bin for one (labeledPoint, feature). 
*/ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -538,7 +550,9 @@ object DecisionTree extends Serializable with Logging { } else { // Perform sequential search to find bin for categorical features. val binIndex = { - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { sequentialBinSearchForOrderedCategoricalFeatureInClassification() @@ -558,6 +572,14 @@ object DecisionTree extends Serializable with Logging { * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. + * + * For unordered features, the "bin index" returned is actually the feature value (category). + * + * @return Array of size 1 + numFeatures * numNodes, where + * arr(0) = label for labeledPoint, and + * arr(1 + numFeatures * nodeIndex + featureIndex) = + * bin index for this labeledPoint + * (or InvalidBinIndex if labeledPoint is not handled by this node) */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. @@ -601,6 +623,15 @@ object DecisionTree extends Serializable with Logging { // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) + /** + * Increment aggregate in location for (node, feature, bin, label). + * + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. + */ def updateBinForOrderedFeature( arr: Array[Double], agg: Array[Double], @@ -618,6 +649,18 @@ object DecisionTree extends Serializable with Logging { agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 } + /** + * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), + * where [bins] ranges over all bins. + * Updates left or right side of aggregate depending on split. + * + * @param arr arr(0) = label. + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param agg Indexed by (left/right, node, feature, bin, label) + * where label is the least significant bit. + * The left/right specifier is a 0/1 index indicating left/right child info. + * @param rightChildShift Offset for right side of agg. + */ def updateBinForUnorderedFeature( nodeIndex: Int, featureIndex: Int, @@ -649,17 +692,15 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. 
* - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. */ - def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { + def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -681,17 +722,21 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numClasses * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * For ordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. + * For unordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). + * @param agg Array storing aggregate calculation. + * For ordered features, this is of size: + * numClasses * numBins * numFeatures * numNodes. + * For unordered features, this is of size: + * 2 * numClasses * numBins * numFeatures * numNodes. */ - def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { + def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -727,14 +772,15 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * Performs a sequential aggregation over a partition for regression. + * For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. * - * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression + * @param agg Array storing aggregate calculation, updated by this function. + * Size: 3 * numBins * numFeatures * numNodes + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. 
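// (Illustration, not part of this patch: the agg arrays above are flat. With the
// (node, feature, bin, label) layout documented for ordered classification features
// and, e.g., numClasses = 3, numBins = 4, numFeatures = 2, the entry for
// (node 1, feature 0, bin 2, label 1) lives at offset
// 3*4*2*1 + 3*4*0 + 3*2 + 1 = 31, with the label as the least significant bit.)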
@@ -767,14 +813,30 @@ object DecisionTree extends Serializable with Logging { /** * Performs a sequential aggregation over a partition. + * For l nodes, k features, + * For classification: + * Either the left count or the right count of one of the bins is + * incremented based upon whether the feature is classified as 0 or 1. + * For regression: + * The count, sum, sum of squares of one of the bins is incremented. + * + * @param agg Array storing aggregate calculation, updated by this function. + * Size for classification: + * numClasses * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. + * Size for regression: + * 3 * numBins * numFeatures * numNodes. + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => if(isMulticlassClassificationWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg) + multiclassWithCategoricalBinSeqOp(arr, agg) } else { - orderedClassificationBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(arr, agg) } case Regression => regressionBinSeqOp(arr, agg) } @@ -937,10 +999,18 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], - * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, - * (numBins - 1), numClasses) + * @param binData Aggregate array slice from getBinDataForNode. + * For classification: + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. + * For ordered features, this is of size numClasses * numBins * numFeatures. + * For regression: + * This is of size 2 * numFeatures * numBins. + * @return (leftNodeAgg, rightNodeAgg) pair of arrays. + * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). + * For regression, each array is of size (numFeatures, (numBins - 1), 3). + * */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { @@ -983,6 +1053,11 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value + */ def findAggForUnorderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1107,7 +1182,7 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. - * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param binData Bin data slice for this node, given by getBinDataForNode. * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -1133,7 +1208,7 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. 
var splitIndex = 0 - val maxSplitIndex : Double = { + val maxSplitIndex: Double = { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { numBins - 1 @@ -1162,8 +1237,8 @@ (bestFeatureIndex, bestSplitIndex, bestGainStats) } + logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } @@ -1214,8 +1289,17 @@ bestSplits } - private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + /** + * Get the number of values to be stored per node in the bin aggregates. + * + * @param numBins Number of bins = 1 + number of possible splits. + */ + private def getElementsPerNode( + numFeatures: Int, + numBins: Int, + numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, + algo: Algo): Int = { algo match { case Classification => if (isMulticlassClassificationWithCategoricalFeatures) { @@ -1228,18 +1312,40 @@ } /** - * Returns split and bins for decision tree calculation. - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * There are 2^(maxFeatureValue - 1) - 1 splits. + * (b) For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one split per category. + + * Categorical case (a) features are called unordered features. + * Other cases are called ordered features. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree + * @return A tuple of (splits, bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numBins - 1). + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins).
*/ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() // Find the number of features by looking at the first sample @@ -1271,7 +1377,8 @@ object DecisionTree extends Serializable with Logging { logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val sampledInput = + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length val stride: Double = numSamples.toDouble / numBins @@ -1306,8 +1413,10 @@ object DecisionTree extends Serializable with Logging { = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + // classification that satisfy the space constraint. + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -1332,8 +1441,13 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { - + } else { // ordered feature + /* For a given categorical feature, use a subsample of the data + * to choose how to arrange possible splits. + * This examines each category and computes a centroid. + * These centroids are later used to sort the possible splits. + * centroidForCategories is a mapping: category (for the given feature) --> centroid + */ val centroidForCategories = { if (isMulticlassClassification) { // For categorical variables in multiclass classification, @@ -1343,7 +1457,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. The bins are sorted and they @@ -1354,7 +1468,7 @@ object DecisionTree extends Serializable with Logging { } } - logDebug("centriod for categories = " + centroidForCategories.mkString(",")) + logDebug("centroid for categories = " + centroidForCategories.mkString(",")) // Check for missing categorical variables and putting them last in the sorted list. 
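// (Illustration, not part of this patch: for binary classification the centroid of a
// category is the mean label of its examples. If a 3-category feature has per-category
// means 0.9, 0.2, 0.5 for categories 0, 1, 2, the sort below orders them as 1, 2, 0,
// and the candidate splits are the prefixes {1} and {1, 2} of that ordering.)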
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() @@ -1369,7 +1483,7 @@ // bins sorted by centroids val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) var categoriesForSplit = List[Double]() categoriesSortedByCentroid.iterator.zipWithIndex.foreach { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index d558ab19dfb35..3d3406b5d5f22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,8 @@ import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: - * Model to store the decision tree parameters + * Decision tree model for classification or regression. + * This model stores the decision tree structure and parameters. * @param topNode root node * @param algo algorithm type -- classification or regression */ From 2b20c6151bab8a2ee218b851f40d54133f9807a2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 31 Jul 2014 13:39:43 -0700 Subject: [PATCH 11/21] Small doc and style updates --- .../spark/examples/mllib/DecisionTreeRunner.scala | 10 ++++------ .../org/apache/spark/mllib/tree/DecisionTree.scala | 10 +++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 9736b624e415d..cf3d2cca81ff6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -144,18 +144,16 @@ object DecisionTreeRunner { println(s"numClasses = $numClasses.") println(s"Per-class example fractions, counts:") println(s"Class\tFrac\tCount") - sortedClasses.foreach { c => { + sortedClasses.foreach { c => val frac = classCounts(c) / numExamples.toDouble println(s"$c\t$frac\t${classCounts(c)}") - }} + } (examples, numClasses) } - case Regression => { + case Regression => (origExamples, 0) - } - case _ => { + case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.") - } } // Split into training, test.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 31574182094cb..7d123dd6ae996 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -43,7 +43,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -206,7 +206,7 @@ object DecisionTree extends Serializable with Logging { * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).train(input) @@ -225,7 +225,7 @@ object DecisionTree extends Serializable with Logging { * @param impurity impurity criterion used for information gain calculation * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -250,7 +250,7 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -284,7 +284,7 @@ object DecisionTree extends Serializable with Logging { * an entry (n -> k) implies the feature n is categorical with k * categories 0, 1, 2, ... , k-1. It's important to note that * features are zero-indexed. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], From b8fac571dc4baa58b4c4c1473bb2969553270865 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 31 Jul 2014 18:56:37 -0700 Subject: [PATCH 12/21] Finished Python DecisionTree API and example but need to test a bit more. --- examples/src/main/python/mllib/tree.py | 54 ++++++++++++++++---------- python/pyspark/mllib/tree.py | 52 ++++++++++++++++++++----- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/examples/src/main/python/mllib/tree.py b/examples/src/main/python/mllib/tree.py index f7af896d3e57c..4b5b3bd08b632 100755 --- a/examples/src/main/python/mllib/tree.py +++ b/examples/src/main/python/mllib/tree.py @@ -19,7 +19,7 @@ Decision tree classification and regression using MLlib. 
""" -import sys +import sys, numpy from operator import add @@ -39,9 +39,23 @@ def parsePoint(line): def getAccuracy(dtModel, data): seqOp = (lambda acc, x: acc + (x[0] == x[1])) trainCorrect = \ - dtModel.predict(data).zip(data.map((lambda p => p.label))).aggregate(0, seqOp, add) + dtModel.predict(data).zip(data.map(lambda p: p.label)).aggregate(0, seqOp, add) return trainCorrect / (0.0 + data.count()) +# Return mean squared error (MSE) of DecisionTreeModel on the given RDD[LabeledPoint]. +def getMSE(dtModel, data): + seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) + trainMSE = \ + dtModel.predict(data).zip(data.map(lambda p: p.label)).aggregate(0, seqOp, add) + return trainMSE / (0.0 + data.count()) + +# Return a new LabeledPoint with the label and feature 0 swapped. +def swapLabelAndFeature0(labeledPoint): + newLabel = labeledPoint.label + newFeatures = labeledPoint.features + (newLabel, newFeatures[0]) = (newFeatures[0], newLabel) + return LabeledPoint(newLabel, newFeatures) + if __name__ == "__main__": if len(sys.argv) != 1: @@ -54,23 +68,23 @@ def getAccuracy(dtModel, data): points = sc.textFile(dataPath).map(parsePoint) # Train a classifier. - model = DecisionTree.trainClassifier(points, numClasses=2) - # Print learned tree. - print "Model numNodes: " + model.numNodes() + "\n" - print "Model depth: " + model.depth() + "\n" - print model - # Check accuracy. - print "Training accuracy: " + getAccuracy(model, points) + "\n" + classificationModel = DecisionTree.trainClassifier(points, numClasses=2) + # Print learned tree and stats. + print "Trained DecisionTree for classification:" + print " Model numNodes: " + classificationModel.numNodes() + "\n" + print " Model depth: " + classificationModel.depth() + "\n" + print " Training accuracy: " + getAccuracy(classificationModel, points) + "\n" + print classificationModel # Switch labels and first feature to create a regression dataset with categorical features. - """ - datasetInfo = DatasetInfo(numClasses=0, numFeatures=numFeatures) - dtParams = DecisionTreeRegressor.defaultParams() - model = DecisionTreeRegressor.train(points, datasetInfo, dtParams) - # Print learned tree. - print "Model numNodes: " + model.numNodes() + "\n" - print "Model depth: " + model.depth() + "\n" - print model - # Check error. - print "Training accuracy: " + getAccuracy(model, points) + "\n" - """ + # Feature 0 is now categorical with 2 categories, and labels are real numbers. + regressionPoints = points.map(lambda labeledPoint: swapLabelAndFeature0(labeledPoint)) + categoricalFeaturesInfo = {0: 2} + regressionModel = \ + DecisionTree.trainRegressor(points, categoricalFeaturesInfo=categoricalFeaturesInfo) + # Print learned tree and stats. 
+ print "Trained DecisionTree for regression:" + print " Model numNodes: " + regressionModel.numNodes() + "\n" + print " Model depth: " + regressionModel.depth() + "\n" + print " Training MSE: " + getMSE(regressionModel, points) + "\n" + print regressionModel diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 4eb9d8e739ecf..3f51448add997 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -18,18 +18,10 @@ from py4j.java_collections import MapConverter from pyspark import SparkContext, RDD -from pyspark.mllib._common import \ - _convert_vector, \ - _dot, _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ - _serialize_double_matrix, _deserialize_double_matrix, \ - _serialize_double_vector, _deserialize_double_vector, \ - _deserialize_labeled_point, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ - _linear_predictor_typecheck, _get_unmangled_labeled_point_rdd -from pyspark.mllib.linalg import SparseVector from pyspark.mllib.regression import LabeledPoint from pyspark.serializers import NoOpSerializer + class DecisionTreeModel(object): """ A decision tree model for classification or regression. @@ -82,6 +74,48 @@ class DecisionTree(object): Learning algorithm for a decision tree model for classification or regression. WARNING: This is an experimental API. It will probably be modified for Spark v1.2. + + Example usage: + >>> from numpy import array, ndarray + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(1.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) + >>> print(model) + DecisionTreeModel classifier + If (feature 0 <= 0.5) + Predict: 0.0 + Else (feature 0 > 0.5) + Predict: 1.0 + + >>> model.predict(array([1.0])) > 0 + True + >>> model.predict(array([0.0])) == 0 + True + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model.predict(array([0.0, 1.0])) == 1 + True + >>> model.predict(array([0.0, 0.0])) == 0 + True + >>> model.predict(SparseVector(2, {1: 1.0})) == 1 + True + >>> model.predict(SparseVector(2, {1: 0.0})) == 0 + True """ def run(self, data, datasetInfo): From 665ba7822bde3cb8105efb31d22e0084265c92da Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 09:42:22 -0700 Subject: [PATCH 13/21] Small updates towards Python DecisionTree API --- examples/src/main/python/mllib/tree.py | 12 ++++++------ python/pyspark/mllib/tree.py | 7 +++++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/src/main/python/mllib/tree.py b/examples/src/main/python/mllib/tree.py index 4b5b3bd08b632..f15c396ec07a7 100755 --- a/examples/src/main/python/mllib/tree.py +++ b/examples/src/main/python/mllib/tree.py @@ -71,9 +71,9 @@ def swapLabelAndFeature0(labeledPoint): classificationModel = DecisionTree.trainClassifier(points, numClasses=2) # Print learned tree and stats. 
print "Trained DecisionTree for classification:" - print " Model numNodes: " + classificationModel.numNodes() + "\n" - print " Model depth: " + classificationModel.depth() + "\n" - print " Training accuracy: " + getAccuracy(classificationModel, points) + "\n" + print " Model numNodes: %d\n" % classificationModel.numNodes() + print " Model depth: %d\n" % classificationModel.depth() + print " Training accuracy: %g\n" % getAccuracy(classificationModel, points) print classificationModel # Switch labels and first feature to create a regression dataset with categorical features. @@ -84,7 +84,7 @@ def swapLabelAndFeature0(labeledPoint): DecisionTree.trainRegressor(points, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. print "Trained DecisionTree for regression:" - print " Model numNodes: " + regressionModel.numNodes() + "\n" - print " Model depth: " + regressionModel.depth() + "\n" - print " Training MSE: " + getMSE(regressionModel, points) + "\n" + print " Model numNodes: %d\n" % regressionModel.numNodes() + print " Model depth: %d\n" % regressionModel.depth() + print " Training MSE: %g\n" % getMSE(regressionModel, points) print regressionModel diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 3f51448add997..2818078300394 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -18,10 +18,12 @@ from py4j.java_collections import MapConverter from pyspark import SparkContext, RDD +from pyspark.mllib._common import \ + _get_unmangled_double_vector_rdd, _serialize_double_vector, \ + _deserialize_labeled_point, _get_unmangled_labeled_point_rdd from pyspark.mllib.regression import LabeledPoint from pyspark.serializers import NoOpSerializer - class DecisionTreeModel(object): """ A decision tree model for classification or regression. @@ -45,7 +47,8 @@ def predict(self, x): :param x: Either one data point (feature vector), or a dataset (RDD of feature vectors) """ pythonAPI = self._sc._jvm.PythonMLLibAPI() - if type(x) == RDD: + print "predict called for type: " + str(type(x)) + if isinstance(x, RDD): # Bulk prediction dataBytes = _get_unmangled_double_vector_rdd(x) jSerializedPreds = pythonAPI.predictDecisionTreeModel(self._java_model, dataBytes._jrdd) From 93953f16e16e4605cbfe8a9e3a26b372e69707ae Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 14:34:54 -0700 Subject: [PATCH 14/21] Likely done with Python API. --- examples/src/main/python/mllib/tree.py | 14 +++++++------ .../mllib/api/python/PythonMLLibAPI.scala | 2 +- python/pyspark/mllib/_common.py | 2 +- python/pyspark/mllib/tree.py | 20 ++++++++++++++----- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/examples/src/main/python/mllib/tree.py b/examples/src/main/python/mllib/tree.py index f15c396ec07a7..cbab84fa1477f 100755 --- a/examples/src/main/python/mllib/tree.py +++ b/examples/src/main/python/mllib/tree.py @@ -38,15 +38,17 @@ def parsePoint(line): # Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. def getAccuracy(dtModel, data): seqOp = (lambda acc, x: acc + (x[0] == x[1])) - trainCorrect = \ - dtModel.predict(data).zip(data.map(lambda p: p.label)).aggregate(0, seqOp, add) + predictions = dtModel.predict(data) + truth = data.map(lambda p: p.label) + trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) return trainCorrect / (0.0 + data.count()) # Return mean squared error (MSE) of DecisionTreeModel on the given RDD[LabeledPoint]. 
def getMSE(dtModel, data): seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) - trainMSE = \ - dtModel.predict(data).zip(data.map(lambda p: p.label)).aggregate(0, seqOp, add) + predictions = dtModel.predict(data) + truth = data.map(lambda p: p.label) + trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) return trainMSE / (0.0 + data.count()) # Return a new LabeledPoint with the label and feature 0 swapped. @@ -81,10 +83,10 @@ def swapLabelAndFeature0(labeledPoint): regressionPoints = points.map(lambda labeledPoint: swapLabelAndFeature0(labeledPoint)) categoricalFeaturesInfo = {0: 2} regressionModel = \ - DecisionTree.trainRegressor(points, categoricalFeaturesInfo=categoricalFeaturesInfo) + DecisionTree.trainRegressor(regressionPoints, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. print "Trained DecisionTree for regression:" print " Model numNodes: %d\n" % regressionModel.numNodes() print " Model depth: %d\n" % regressionModel.depth() - print " Training MSE: %g\n" % getMSE(regressionModel, points) + print " Training MSE: %g\n" % getMSE(regressionModel, regressionPoints) print regressionModel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 0a0a296a73c57..7d579eb586d66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -530,7 +530,7 @@ class PythonMLLibAPI extends Serializable { model: DecisionTreeModel, dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) - model.predict(data).map(Utils.serialize(_)) + model.predict(data).map(serializeDouble) } // Used by the *RDD methods to get default seed if not passed in from pyspark diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 8e3ad6b783b6c..2d5b8d535b17b 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -376,7 +376,7 @@ def _linear_predictor_typecheck(x, coeffs): if x.size != coeffs.shape[0]: raise RuntimeError("Got sparse vector of size %d; wanted %d" % ( x.size, coeffs.shape[0])) - elif (type(x) == RDD): + elif isinstance(x, RDD): raise RuntimeError("Bulk predict not yet supported.") else: raise TypeError("Argument of type " + type(x).__name__ + " unsupported") diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 2818078300394..9e7b9721a63b9 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,8 +19,9 @@ from pyspark import SparkContext, RDD from pyspark.mllib._common import \ - _get_unmangled_double_vector_rdd, _serialize_double_vector, \ - _deserialize_labeled_point, _get_unmangled_labeled_point_rdd + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \ + _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \ + _deserialize_double from pyspark.mllib.regression import LabeledPoint from pyspark.serializers import NoOpSerializer @@ -44,16 +45,24 @@ def __del__(self): def predict(self, x): """ - :param x: Either one data point (feature vector), or a dataset (RDD of feature vectors) + Predict the label of one or more examples. + NOTE: This currently does NOT support batch prediction. + + :param x: Data point: feature vector, or a LabeledPoint (whose label is ignored). 
""" pythonAPI = self._sc._jvm.PythonMLLibAPI() - print "predict called for type: " + str(type(x)) if isinstance(x, RDD): # Bulk prediction + if x.count() == 0: + raise RuntimeError("DecisionTreeModel.predict(x) given empty RDD x.") + elementType = type(x.take(1)[0]) + if elementType == LabeledPoint: + x = x.map(lambda x: x.features) dataBytes = _get_unmangled_double_vector_rdd(x) jSerializedPreds = pythonAPI.predictDecisionTreeModel(self._java_model, dataBytes._jrdd) + dataBytes.unpersist() serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer()) - return serializedPreds.map(lambda bytes: _deserialize_labeled_point(bytearray(bytes))) + return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes))) else: if type(x) == LabeledPoint: x_ = _serialize_double_vector(x.features) @@ -202,6 +211,7 @@ def train(data, algo, numClasses, categoricalFeaturesInfo, impurity, maxDepth, m dataBytes._jrdd, algo, numClasses, categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins) + dataBytes.unpersist() return DecisionTreeModel(sc, model) From 225822fe38762596b8c917a867e5cdbb2d9b4b55 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 14:50:42 -0700 Subject: [PATCH 15/21] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. Added new test to DecisionTreeSuite to catch this: "regression stump with categorical variables of arity 2" Bug fix: Modified upper bound discussed above. Also: Small improvements to coding style in DecisionTree. --- .../spark/mllib/tree/DecisionTree.scala | 45 +++++++++++-------- .../spark/mllib/tree/DecisionTreeSuite.scala | 29 ++++++++++++ 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 7d123dd6ae996..382e76a9b7cba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging { val bin = binForFeatures(mid) val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)){ + if ((lowThreshold < feature) && (highThreshold >= feature)) { return mid } else if (lowThreshold >= feature) { @@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging { } /** - * Sequential search helper method to find bin for categorical feature. + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). 
*/ - def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureValue = labeledPoint.features(featureIndex) var binIndex = 0 - while (binIndex < numCategoricalBins) { + while (binIndex < featureCategories) { val bin = bins(featureIndex)(binIndex) val categories = bin.highSplit.categories - val features = labeledPoint.features - if (categories.contains(features(featureIndex))) { + if (categories.contains(featureValue)) { return binIndex } binIndex += 1 } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } -1 } if (isFeatureContinuous) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for continuous variable.") } binIndex @@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging { if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { - sequentialBinSearchForOrderedCategoricalFeatureInClassification() + sequentialBinSearchForOrderedCategoricalFeature() } } - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for categorical variable.") } binIndex @@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging { val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + val aggIndex = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + numClasses * arr(arrIndex).toInt + + label.toInt + agg(aggIndex) += 1 } /** @@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging { val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures){ + if (isMulticlassClassificationWithCategoricalFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 - while (featureIndex < numFeatures){ + while (featureIndex < numFeatures) { // Check whether the feature is continuous. val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { @@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { // Bins for categorical variables are already assigned. 
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
           splits(featureIndex)(0), Continuous, Double.MinValue)
-        for (index <- 1 until numBins - 1){
+        for (index <- 1 until numBins - 1) {
           val bin = new Bin(splits(featureIndex)(index-1),
             splits(featureIndex)(index), Continuous, Double.MinValue)
           bins(featureIndex)(index) = bin
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 10462db700628..546a132559326 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(accuracy >= requiredAccuracy)
   }

+  def validateRegressor(
+      model: DecisionTreeModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+      (prediction - expected.label) * (prediction - expected.label)
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE)
+  }
+
   test("split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
@@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
   }

+  test("regression stump with categorical variables of arity 2") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Regression,
+      Variance,
+      maxDepth = 2,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
+
+    val model = DecisionTree.train(rdd, strategy)
+    validateRegressor(model, arr, 0.0)
+    assert(model.numNodes === 3)
+    assert(model.depth === 1)
+  }
+
   test("stump with fixed label 0 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)

From 4801b40a704037c124a3cf30b02b2dd31ab1a785 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 15:40:12 -0700
Subject: [PATCH 16/21] Small style update to DecisionTreeSuite

---
 .../scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 546a132559326..8665a00f3b356 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -48,7 +48,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       requiredMSE: Double) {
     val predictions = input.map(x => model.predict(x.features))
     val squaredError = predictions.zip(input).map { case (prediction, expected) =>
-      (prediction - expected.label) * (prediction - expected.label)
+      val err = prediction - expected.label
+      err * err
     }.sum
     val mse = squaredError / input.length
     assert(mse <= requiredMSE)

From 7968692e9d1c211f283554d7d23fef1154e3b579 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 16:01:52 -0700
Bradley" Date: Fri, 1 Aug 2014 16:01:52 -0700 Subject: [PATCH 17/21] small braces typo fix --- .../scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fcec058400cdb..8ceaa166089e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -533,6 +533,7 @@ class PythonMLLibAPI extends Serializable { dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) model.predict(data).map(serializeDouble) + } /** * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String). From bf21be43303d269f7edb8aeeaa4f700b7dc0d815 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 16:58:01 -0700 Subject: [PATCH 18/21] removed old run() func from DecisionTree --- python/pyspark/mllib/tree.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 9e7b9721a63b9..08d974deaf22e 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -130,15 +130,6 @@ class DecisionTree(object): True """ - def run(self, data, datasetInfo): - """ - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector. - Labels are integers {0,1,...,numClasses}. - :param datasetInfo: Dataset metadata - :return: DecisionTreeClassifierModel - """ - @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, impurity="gini", maxDepth=4, maxBins=100): From cf46ad7637ba37081038c5c508a0a46e82689ec6 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 23:01:30 -0700 Subject: [PATCH 19/21] Python DecisionTreeModel * predict(empty RDD) returns an empty RDD instead of an error. * Removed support for calling predict() on LabeledPoint and RDD[LabeledPoint] * predict() does not cache serialized RDD any more. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PythonMLLibAPI.scala * Using JavaConverters instead of JavaConversions pyspark/mllib/_common.py * Updated _get_unmangled_*rdd() methods to take cache option defaulting to True (original behavior). Testing and examples: * Moved unit tests to pyspark/mllib/tests.py * Updated examples/…/tree.py ** Takes a libsvm dataset filepath as an optional argument. ** Re-indexes classes as needed to handle more libsvm datasets. Other stuff: * Removed bad “@deprecated” tag from python/util.py and replaced with warnings.warn(...) * Added python/mllib.util.py to python/run-tests script. 
* Small doc and style updates --- .../main/python/mllib/logistic_regression.py | 4 +- examples/src/main/python/mllib/tree.py | 119 ++++++++++++------ .../mllib/api/python/PythonMLLibAPI.scala | 6 +- .../mllib/tree/configuration/Strategy.scala | 3 +- python/pyspark/mllib/_common.py | 31 +++-- python/pyspark/mllib/tests.py | 38 ++++++ python/pyspark/mllib/tree.py | 80 ++++++------ python/pyspark/mllib/util.py | 5 +- python/run-tests | 1 + 9 files changed, 193 insertions(+), 94 deletions(-) diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 6e0f7a4ee5a81..9d547ff77c984 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -30,8 +30,10 @@ from pyspark.mllib.classification import LogisticRegressionWithSGD -# Parse a line of text into an MLlib LabeledPoint object def parsePoint(line): + """ + Parse a line of text into an MLlib LabeledPoint object. + """ values = [float(s) for s in line.split(' ')] if values[0] == -1: # Convert -1 labels to 0 for MLlib values[0] = 0 diff --git a/examples/src/main/python/mllib/tree.py b/examples/src/main/python/mllib/tree.py index cbab84fa1477f..e415368e5bd9f 100755 --- a/examples/src/main/python/mllib/tree.py +++ b/examples/src/main/python/mllib/tree.py @@ -19,74 +19,111 @@ Decision tree classification and regression using MLlib. """ -import sys, numpy +import numpy, os, sys from operator import add from pyspark import SparkContext from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils -# Parse a line of text into an MLlib LabeledPoint object -def parsePoint(line): - values = [float(s) for s in line.split(',')] - if values[0] == -1: # Convert -1 labels to 0 for MLlib - values[0] = 0 - return LabeledPoint(values[0], values[1:]) - -# Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. def getAccuracy(dtModel, data): + """ + Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. + """ seqOp = (lambda acc, x: acc + (x[0] == x[1])) - predictions = dtModel.predict(data) + predictions = dtModel.predict(data.map(lambda x: x.features)) truth = data.map(lambda p: p.label) trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) return trainCorrect / (0.0 + data.count()) -# Return mean squared error (MSE) of DecisionTreeModel on the given RDD[LabeledPoint]. + def getMSE(dtModel, data): + """ + Return mean squared error (MSE) of DecisionTreeModel on the given + RDD[LabeledPoint]. + """ seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) - predictions = dtModel.predict(data) + predictions = dtModel.predict(data.map(lambda x: x.features)) truth = data.map(lambda p: p.label) trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) return trainMSE / (0.0 + data.count()) -# Return a new LabeledPoint with the label and feature 0 swapped. -def swapLabelAndFeature0(labeledPoint): - newLabel = labeledPoint.label - newFeatures = labeledPoint.features - (newLabel, newFeatures[0]) = (newFeatures[0], newLabel) - return LabeledPoint(newLabel, newFeatures) + +def reindexClassLabels(data): + """ + Re-index class labels in a dataset to the range {0,...,numClasses-1}. + If all labels in that range already appear at least once, + then the returned RDD is the same one (without a mapping). + Note: If a label simply does not appear in the data, + the index will not include it. 
+ Be aware of this when reindexing subsampled data. + :param data: RDD of LabeledPoint where labels are integer values + denoting labels for a classification problem. + :return: Pair (reindexedData, origToNewLabels) where + reindexedData is an RDD of LabeledPoint with labels in + the range {0,...,numClasses-1}, and + origToNewLabels is a dictionary mapping original labels + to new labels. + """ + # classCounts: class --> # examples in class + classCounts = data.map(lambda x: x.label).countByValue() + numExamples = sum(classCounts.values()) + sortedClasses = sorted(classCounts.keys()) + numClasses = len(classCounts) + # origToNewLabels: class --> index in 0,...,numClasses-1 + if (numClasses < 2): + print >> sys.stderr, \ + "Dataset for classification should have at least 2 classes." + \ + " The given dataset had only %d classes." % numClasses + exit(-1) + origToNewLabels = dict([(sortedClasses[i], i) for i in range(0,numClasses)]) + + print "numClasses = %d" % numClasses + print "Per-class example fractions, counts:" + print "Class\tFrac\tCount" + for c in sortedClasses: + frac = classCounts[c] / (numExamples + 0.0) + print "%g\t%g\t%d" % (c, frac, classCounts[c]) + + if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): + return (data, origToNewLabels) + else: + reindexedData = \ + data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) + return (reindexedData, origToNewLabels) + + +def usage(): + print >> sys.stderr, \ + "Usage: logistic_regression [libsvm format data filepath]\n" + \ + " Note: This only supports binary classification." + exit(-1) if __name__ == "__main__": - if len(sys.argv) != 1: - print >> sys.stderr, "Usage: logistic_regression" - exit(-1) + if len(sys.argv) > 2: + usage() sc = SparkContext(appName="PythonDT") # Load data. - dataPath = 'data/mllib/sample_tree_data.csv' - points = sc.textFile(dataPath).map(parsePoint) + dataPath = 'data/mllib/sample_libsvm_data.txt' + if len(sys.argv) == 2: + dataPath = sys.argv[1] + if not os.path.isfile(dataPath): + usage() + points = MLUtils.loadLibSVMFile(sc, dataPath) + + # Re-index class labels if needed. + (reindexedData, origToNewLabels) = reindexClassLabels(points) # Train a classifier. - classificationModel = DecisionTree.trainClassifier(points, numClasses=2) + model = DecisionTree.trainClassifier(reindexedData, numClasses=2) # Print learned tree and stats. print "Trained DecisionTree for classification:" - print " Model numNodes: %d\n" % classificationModel.numNodes() - print " Model depth: %d\n" % classificationModel.depth() - print " Training accuracy: %g\n" % getAccuracy(classificationModel, points) - print classificationModel - - # Switch labels and first feature to create a regression dataset with categorical features. - # Feature 0 is now categorical with 2 categories, and labels are real numbers. - regressionPoints = points.map(lambda labeledPoint: swapLabelAndFeature0(labeledPoint)) - categoricalFeaturesInfo = {0: 2} - regressionModel = \ - DecisionTree.trainRegressor(regressionPoints, categoricalFeaturesInfo=categoricalFeaturesInfo) - # Print learned tree and stats. 
- print "Trained DecisionTree for regression:" - print " Model numNodes: %d\n" % regressionModel.numNodes() - print " Model depth: %d\n" % regressionModel.depth() - print " Training MSE: %g\n" % getMSE(regressionModel, regressionPoints) - print regressionModel + print " Model numNodes: %d\n" % model.numNodes() + print " Model depth: %d\n" % model.depth() + print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) + print model diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 8ceaa166089e6..0fb0144351107 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} @@ -477,7 +477,7 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD: JavaRDD[Array[Byte]], algoStr: String, numClasses: Int, - categoricalFeaturesInfoJMap: java.util.Map[Int,Int], + categoricalFeaturesInfoJMap: java.util.Map[Int, Int], impurityStr: String, maxDepth: Int, maxBins: Int): DecisionTreeModel = { @@ -502,7 +502,7 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.toMap) + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) DecisionTree.train(data, strategy) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 5c65b537b6867..fdad4f029aa99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -56,7 +56,8 @@ class Strategy ( if (algo == Classification) { require(numClassesForClassification >= 2) } - val isMulticlassClassification = numClassesForClassification > 2 + val isMulticlassClassification = + algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 7664511fef01b..9c1565affbdac 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype): temp_array[...] = array -def _get_unmangled_rdd(data, serializer): +def _get_unmangled_rdd(data, serializer, cache=True): + """ + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ dataBytes = data.map(serializer) dataBytes._bypass_serializer = True - dataBytes.cache() # TODO: users should unpersist() this later! 
+ if cache: + dataBytes.cache() return dataBytes -# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - return _get_unmangled_rdd(data, _serialize_double_vector) +def _get_unmangled_double_vector_rdd(data, cache=True): + """ + Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of + _serialized_double_vectors. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_double_vector, cache) -# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points -def _get_unmangled_labeled_point_rdd(data): - return _get_unmangled_rdd(data, _serialize_labeled_point) +def _get_unmangled_labeled_point_rdd(data, cache=True): + """ + Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_labeled_point, cache) # Common functions for dealing with and training linear models diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 37ccf1d590743..fedfe4fb71f8b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -100,6 +100,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -127,9 +128,19 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = \ + DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -157,6 +168,14 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = \ + DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -229,6 +248,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -256,9 +276,19 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) 
self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = \ + DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -286,6 +316,14 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = \ + DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + if __name__ == "__main__": if not _have_scipy: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 08d974deaf22e..1e0006df75ac6 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -29,7 +29,8 @@ class DecisionTreeModel(object): """ A decision tree model for classification or regression. - WARNING: This is an experimental API. It will probably be modified for Spark v1.2. + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. """ def __init__(self, sc, java_model): @@ -46,29 +47,23 @@ def __del__(self): def predict(self, x): """ Predict the label of one or more examples. - NOTE: This currently does NOT support batch prediction. - - :param x: Data point: feature vector, or a LabeledPoint (whose label is ignored). + :param x: Data point (feature vector), + or an RDD of data points (feature vectors). """ pythonAPI = self._sc._jvm.PythonMLLibAPI() if isinstance(x, RDD): # Bulk prediction if x.count() == 0: - raise RuntimeError("DecisionTreeModel.predict(x) given empty RDD x.") - elementType = type(x.take(1)[0]) - if elementType == LabeledPoint: - x = x.map(lambda x: x.features) - dataBytes = _get_unmangled_double_vector_rdd(x) - jSerializedPreds = pythonAPI.predictDecisionTreeModel(self._java_model, dataBytes._jrdd) - dataBytes.unpersist() + return self._sc.parallelize([]) + dataBytes = _get_unmangled_double_vector_rdd(x, cache=False) + jSerializedPreds = \ + pythonAPI.predictDecisionTreeModel(self._java_model, + dataBytes._jrdd) serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer()) return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes))) else: - if type(x) == LabeledPoint: - x_ = _serialize_double_vector(x.features) - else: - # Assume x is a single data point. - x_ = _serialize_double_vector(x) + # Assume x is a single data point. + x_ = _serialize_double_vector(x) return pythonAPI.predictDecisionTreeModel(self._java_model, x_) def numNodes(self): @@ -83,9 +78,11 @@ def __str__(self): class DecisionTree(object): """ - Learning algorithm for a decision tree model for classification or regression. + Learning algorithm for a decision tree model + for classification or regression. - WARNING: This is an experimental API. 
It will probably be modified for Spark v1.2. + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. Example usage: >>> from numpy import array, ndarray @@ -136,12 +133,13 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, """ Train a DecisionTreeModel for classification. - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector. + :param data: Training data: RDD of LabeledPoint. Labels are integers {0,1,...,numClasses}. :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index to number of categories. - Any feature not in this map is treated as continuous. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. :param impurity: Supported values: "entropy" or "gini" :param maxDepth: Max depth of tree. E.g., depth 0 means 1 leaf node. @@ -149,7 +147,8 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "classification", numClasses, categoricalFeaturesInfo, + return DecisionTree.train(data, "classification", numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins) @staticmethod @@ -158,11 +157,12 @@ def trainRegressor(data, categoricalFeaturesInfo={}, """ Train a DecisionTreeModel for regression. - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector. + :param data: Training data: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature index to number of categories. - Any feature not in this map is treated as continuous. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. :param impurity: Supported values: "variance" :param maxDepth: Max depth of tree. E.g., depth 0 means 1 leaf node. @@ -170,24 +170,29 @@ def trainRegressor(data, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "regression", 0, categoricalFeaturesInfo, + return DecisionTree.train(data, "regression", 0, + categoricalFeaturesInfo, impurity, maxDepth, maxBins) @staticmethod - def train(data, algo, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins=100): + def train(data, algo, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins=100): """ Train a DecisionTreeModel for classification or regression. - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector. - For classification, labels are integers {0,1,...,numClasses}. + :param data: Training data: RDD of LabeledPoint. + For classification, labels are integers + {0,1,...,numClasses}. For regression, labels are real numbers. :param algo: "classification" or "regression" - :param numClasses: Number of classes for classification. 0 or 1 indicates regression. - :param categoricalFeaturesInfo: Map from categorical feature index to number of categories. - Any feature not in this map is treated as continuous. - :param impurity: For classification: "entropy" or "gini". For regression: "variance". 
+ :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: For classification: "entropy" or "gini". + For regression: "variance". :param maxDepth: Max depth of tree. E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. @@ -197,7 +202,8 @@ def train(data, algo, numClasses, categoricalFeaturesInfo, impurity, maxDepth, m sc = data.context dataBytes = _get_unmangled_labeled_point_rdd(data) categoricalFeaturesInfoJMap = \ - MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, algo, numClasses, categoricalFeaturesInfoJMap, diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index d94900cefdb77..c9302ebe41f40 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -16,6 +16,7 @@ # import numpy as np +import warnings from pyspark.mllib.linalg import Vectors, SparseVector from pyspark.mllib.regression import LabeledPoint @@ -29,9 +30,9 @@ class MLUtils: Helper methods to load, save and pre-process data used in MLlib. """ - @deprecated @staticmethod def _parse_libsvm_line(line, multiclass): + warnings.warn("deprecated", DeprecationWarning) return _parse_libsvm_line(line) @staticmethod @@ -67,9 +68,9 @@ def _convert_labeled_point_to_libsvm(p): " but got " % type(v)) return " ".join(items) - @deprecated @staticmethod def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + warnings.warn("deprecated", DeprecationWarning) return loadLibSVMFile(sc, path, numFeatures, minPartitions) @staticmethod diff --git a/python/run-tests b/python/run-tests index 5049e15ce5f8a..48feba2f5bd63 100755 --- a/python/run-tests +++ b/python/run-tests @@ -71,6 +71,7 @@ run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green From affceb96b9725d79568d1b8b8fc9bb53d9de3806 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 23:27:40 -0700 Subject: [PATCH 20/21] * Fixed bug in doc tests in pyspark/mllib/util.py caused by change in loadLibSVMFile behavior. (It used to threshold labels at 0 to make them 0/1, but it now leaves them as they are.) * Fixed small bug in loadLibSVMFile: If a data file had no features, then loadLibSVMFile would create a single all-zero feature. 
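To make the second fix concrete, here is a minimal before/after sketch of the
numFeatures inference (the two numFeatures lines are taken verbatim from the
diff below; parsed is the RDD of (label, indices, values) tuples produced by
_parse_libsvm_line):

    # Before: a point with no features mapped to index 0, so a file containing
    # only feature-less points inferred max(0) + 1 = 1, i.e., a phantom
    # all-zero feature.
    numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
    # After: a point with no features maps to -1, so such a file correctly
    # infers max(-1) + 1 = 0 features.
    numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1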
---
 python/pyspark/mllib/util.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index c9302ebe41f40..639cda6350229 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -107,7 +107,6 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None):
     >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
     >>> tempFile.flush()
     >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
-    >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
     >>> tempFile.close()
     >>> type(examples[0]) == LabeledPoint
     True
@@ -116,20 +115,18 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None):
     >>> type(examples[1]) == LabeledPoint
     True
     >>> print examples[1]
-    (0.0,(6,[],[]))
+    (-1.0,(6,[],[]))
     >>> type(examples[2]) == LabeledPoint
     True
     >>> print examples[2]
-    (0.0,(6,[1,3,5],[4.0,5.0,6.0]))
-    >>> multiclass_examples[1].label
-    -1.0
+    (-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
     """
     lines = sc.textFile(path, minPartitions)
     parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l))
     if numFeatures <= 0:
         parsed.cache()
-        numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
+        numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
     return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2])))

 @staticmethod

From 374448874de7a758658e0ac54cf1d578d09e347d Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" Date: Sat, 2 Aug 2014 11:53:37 -0700
Subject: [PATCH 21/21] Renamed the example script tree.py to
 decision_tree_runner.py

Small updates based on GitHub review.
---
 .../mllib/{tree.py => decision_tree_runner.py} | 12 ++++++------
 python/pyspark/mllib/tests.py                  |  8 +++-----
 2 files changed, 11 insertions(+), 9 deletions(-)
 rename examples/src/main/python/mllib/{tree.py => decision_tree_runner.py} (95%)

diff --git a/examples/src/main/python/mllib/tree.py b/examples/src/main/python/mllib/decision_tree_runner.py
similarity index 95%
rename from examples/src/main/python/mllib/tree.py
rename to examples/src/main/python/mllib/decision_tree_runner.py
index e415368e5bd9f..8efadb5223f56 100755
--- a/examples/src/main/python/mllib/tree.py
+++ b/examples/src/main/python/mllib/decision_tree_runner.py
@@ -37,6 +37,8 @@ def getAccuracy(dtModel, data):
     predictions = dtModel.predict(data.map(lambda x: x.features))
     truth = data.map(lambda p: p.label)
     trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add)
+    if data.count() == 0:
+        return 0
     return trainCorrect / (0.0 + data.count())
@@ -49,6 +51,8 @@ def getMSE(dtModel, data):
     predictions = dtModel.predict(data.map(lambda x: x.features))
     truth = data.map(lambda p: p.label)
     trainMSE = predictions.zip(truth).aggregate(0, seqOp, add)
+    if data.count() == 0:
+        return 0
     return trainMSE / (0.0 + data.count())
@@ -78,8 +82,8 @@ def reindexClassLabels(data):
     print >> sys.stderr, \
         "Dataset for classification should have at least 2 classes." + \
         " The given dataset had only %d classes." % numClasses
-        exit(-1)
-    origToNewLabels = dict([(sortedClasses[i], i) for i in range(0,numClasses)])
+        exit(1)
+    origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)])

     print "numClasses = %d" % numClasses
     print "Per-class example fractions, counts:"
     print "Class\tFrac\tCount"
     for c in sortedClasses:
@@ -98,9 +102,9 @@ def reindexClassLabels(data):

 def usage():
     print >> sys.stderr, \
-        "Usage: logistic_regression [libsvm format data filepath]\n" + \
+        "Usage: decision_tree_runner [libsvm format data filepath]\n" + \
         " Note: This only supports binary classification."
-    exit(-1)
+    exit(1)


 if __name__ == "__main__":
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index fedfe4fb71f8b..9d1e5be637a9a 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -277,9 +277,8 @@ def test_classification(self):
         self.assertTrue(nb_model.predict(features[3]) > 0)

         categoricalFeaturesInfo = {0: 3}  # feature 0 has 3 categories
-        dt_model = \
-            DecisionTree.trainClassifier(rdd, numClasses=2,
-                                         categoricalFeaturesInfo=categoricalFeaturesInfo)
+        dt_model = DecisionTree.trainClassifier(rdd, numClasses=2,
+                                                categoricalFeaturesInfo=categoricalFeaturesInfo)
         self.assertTrue(dt_model.predict(features[0]) <= 0)
         self.assertTrue(dt_model.predict(features[1]) > 0)
         self.assertTrue(dt_model.predict(features[2]) <= 0)
@@ -317,8 +316,7 @@ def test_regression(self):
         self.assertTrue(rr_model.predict(features[3]) > 0)

         categoricalFeaturesInfo = {0: 2}  # feature 0 has 2 categories
-        dt_model = \
-            DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+        dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
         self.assertTrue(dt_model.predict(features[0]) <= 0)
         self.assertTrue(dt_model.predict(features[1]) > 0)
         self.assertTrue(dt_model.predict(features[2]) <= 0)
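With the series applied, the finished Python API can be exercised end to end.
The following is a minimal sketch (not part of the patches above), assuming a
live SparkContext sc; it mirrors the doctest in pyspark/mllib/tree.py and the
decision_tree_runner.py example:

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    # Tiny toy dataset: label 0.0 below feature value 0.5, label 1.0 above it.
    data = [LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0]),
            LabeledPoint(1.0, [2.0]), LabeledPoint(1.0, [3.0])]
    points = sc.parallelize(data)
    model = DecisionTree.trainClassifier(points, numClasses=2)
    print model  # prints the learned tree, e.g. "If (feature 0 <= 0.5) ..."
    # As of PATCH 19, bulk prediction takes an RDD of feature vectors
    # (not LabeledPoints):
    predictions = model.predict(points.map(lambda p: p.features))
    print predictions.collect()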