From 225822fe38762596b8c917a867e5cdbb2d9b4b55 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Fri, 1 Aug 2014 14:50:42 -0700
Subject: [PATCH 1/6] Bug: In DecisionTree, the method
 sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed
 bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper
 bound is the bound for unordered categorical features, not ordered ones.
 The upper bound should be the arity (i.e., the number of categories) of
 the feature.

Added new test to DecisionTreeSuite to catch this:
  "regression stump with categorical variables of arity 2"

Bug fix: Modified upper bound discussed above.

Also: Small improvements to coding style in DecisionTree.
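Background for the bound in question: ordered categorical features get one
bin per category value, while unordered categorical features get one bin per
candidate split, i.e., per nonempty subset of categories sent left (subsets
and their complements are equivalent). A minimal Scala sketch of the two bin
counts (helper names are illustrative, not code from this patch):

    // Ordered categorical feature: one bin per category value 0..arity-1.
    def numBinsOrdered(arity: Int): Int = arity

    // Unordered categorical feature: 2^(arity-1) - 1 candidate splits,
    // since each subset of categories is equivalent to its complement.
    def numBinsUnordered(arity: Int): Int = math.pow(2.0, arity - 1).toInt - 1

    // For arity 2: numBinsOrdered(2) == 2, but numBinsUnordered(2) == 1.

This is why the arity-2 regression test added below catches the bug: with the
old (unordered) bound of 1, the sequential search never examines the second
bin, so a data point with feature value 1 finds no bin at all.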
---
 .../spark/mllib/tree/DecisionTree.scala       | 45 +++++++++++--------
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 29 ++++++++++++
 2 files changed, 56 insertions(+), 18 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 7d123dd6ae996..382e76a9b7cba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging {
         val bin = binForFeatures(mid)
         val lowThreshold = bin.lowSplit.threshold
         val highThreshold = bin.highSplit.threshold
-        if ((lowThreshold < feature) && (highThreshold >= feature)){
+        if ((lowThreshold < feature) && (highThreshold >= feature)) {
           return mid
         }
         else if (lowThreshold >= feature) {
@@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging {
       }
 
       /**
-       * Sequential search helper method to find bin for categorical feature.
+       * Sequential search helper method to find bin for categorical feature
+       * (for classification and regression).
        */
-      def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+      def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
         val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-        val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+        val featureValue = labeledPoint.features(featureIndex)
         var binIndex = 0
-        while (binIndex < numCategoricalBins) {
+        while (binIndex < featureCategories) {
           val bin = bins(featureIndex)(binIndex)
           val categories = bin.highSplit.categories
-          val features = labeledPoint.features
-          if (categories.contains(features(featureIndex))) {
+          if (categories.contains(featureValue)) {
             return binIndex
           }
           binIndex += 1
         }
+        if (featureValue < 0 || featureValue >= featureCategories) {
+          throw new IllegalArgumentException(
+            s"DecisionTree given invalid data:" +
+            s" Feature $featureIndex is categorical with values in" +
+            s" {0,...,${featureCategories - 1}}," +
+            s" but a data point gives it value $featureValue.\n" +
+            "  Bad data point: " + labeledPoint.toString)
+        }
         -1
       }
 
       if (isFeatureContinuous) {
        // Perform binary search for finding bin for continuous features.
        val binIndex = binarySearchForBins()
-       if (binIndex == -1){
+       if (binIndex == -1) {
          throw new UnknownError("no bin was found for continuous variable.")
        }
        binIndex
@@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging {
          if (isUnorderedFeature) {
            sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
          } else {
-           sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+           sequentialBinSearchForOrderedCategoricalFeature()
          }
        }
-       if (binIndex == -1){
+       if (binIndex == -1) {
          throw new UnknownError("no bin was found for categorical variable.")
        }
        binIndex
@@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging {
      val arrShift = 1 + numFeatures * nodeIndex
      val arrIndex = arrShift + featureIndex
      // Update the left or right count for one bin.
-     val aggShift = numClasses * numBins * numFeatures * nodeIndex
-     val aggIndex
-       = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
-     val labelInt = label.toInt
-     agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+     val aggIndex =
+       numClasses * numBins * numFeatures * nodeIndex +
+       numClasses * numBins * featureIndex +
+       numClasses * arr(arrIndex).toInt +
+       label.toInt
+     agg(aggIndex) += 1
    }
 
@@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging {
     val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      if (isMulticlassClassificationWithCategoricalFeatures){
+      if (isMulticlassClassificationWithCategoricalFeatures) {
        val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
        if (isFeatureContinuous) {
          findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging {
 
     // Iterate over all features.
     var featureIndex = 0
-    while (featureIndex < numFeatures){
+    while (featureIndex < numFeatures) {
       // Check whether the feature is continuous.
       val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
       if (isFeatureContinuous) {
@@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging {
        if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
          bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
            splits(featureIndex)(0), Continuous, Double.MinValue)
-         for (index <- 1 until numBins - 1){
+         for (index <- 1 until numBins - 1) {
            val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
              Continuous, Double.MinValue)
            bins(featureIndex)(index) = bin
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 10462db700628..546a132559326 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(accuracy >= requiredAccuracy)
   }
 
+  def validateRegressor(
+      model: DecisionTreeModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+      (prediction - expected.label) * (prediction - expected.label)
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE)
+  }
+
   test("split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
@@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
   }
 
+  test("regression stump with categorical variables of arity 2") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Regression,
+      Variance,
+      maxDepth = 2,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
+
+    val model = DecisionTree.train(rdd, strategy)
+    validateRegressor(model, arr, 0.0)
+    assert(model.numNodes === 3)
+    assert(model.depth === 1)
+  }
+
   test("stump with fixed label 0 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
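A note on the aggIndex rewrite in the patch above: it replaces the old
shift-based arithmetic with the standard row-major flattening of a
conceptual 4-D array indexed as [node][feature][bin][label]. A minimal
Scala sketch of the identity, with illustrative names (not code from the
patch):

    // Row-major flattening: agg is conceptually
    //   Array[numNodes][numFeatures][numBins][numClasses].
    def flatIndex(numClasses: Int, numBins: Int, numFeatures: Int,
        node: Int, feature: Int, bin: Int, label: Int): Int = {
      numClasses * numBins * numFeatures * node +
        numClasses * numBins * feature +
        numClasses * bin +
        label
    }

    // E.g., with 2 classes, 3 bins, 2 features:
    // flatIndex(2, 3, 2, node = 1, feature = 1, bin = 2, label = 1) == 23,
    // the last cell of a 2 x 2 x 3 x 2 array.

The old code computed the same offset, but it grouped the terms so that the
layout was hard to see; the new form makes the node/feature/bin/label
nesting explicit without changing behavior.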
From f1a8283c5cb6a497a9ac60c8ce1859dbe9a051b0 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Fri, 1 Aug 2014 15:56:09 -0700
Subject: [PATCH 2/6] Added old JavaDecisionTreeSuite, to be updated later

---
 .../mllib/tree/JavaDecisionTreeSuite.java     | 91 +++++++++++++++++++
 1 file changed, 91 insertions(+)
 create mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java

diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
new file mode 100644
index 0000000000000..13e86ec9f4161
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.rdd.DatasetInfo;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.configuration.DTClassifierParams;
+import org.apache.spark.mllib.tree.model.DecisionTreeClassifierModel;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.List;
+
+public class JavaDecisionTreeSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  int validatePrediction(List<LabeledPoint> validationData, DecisionTreeClassifierModel model) {
+    int numAccurate = 0;
+    for (LabeledPoint point: validationData) {
+      Double prediction = model.predict(point.features());
+      if (prediction == point.label()) {
+        numAccurate++;
+      }
+    }
+    return numAccurate;
+  }
+
+  @Test
+  public void runDTUsingConstructor() {
+    scala.Tuple2<java.util.List<LabeledPoint>, DatasetInfo> arr_datasetInfo =
+      DecisionTreeSuite.generateCategoricalDataPointsAsList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr_datasetInfo._1());
+    DatasetInfo datasetInfo = arr_datasetInfo._2();
+
+    DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams();
+    dtParams.setMaxBins(200);
+    dtParams.setImpurity("entropy");
+    DecisionTreeClassifier dtLearner = new DecisionTreeClassifier(dtParams);
+    DecisionTreeClassifierModel model = dtLearner.run(rdd.rdd(), datasetInfo);
+
+    int numAccurate = validatePrediction(arr_datasetInfo._1(), model);
+    Assert.assertTrue(numAccurate == rdd.count());
+  }
+/*
+  @Test
+  public void runDTUsingStaticMethods() {
+    scala.Tuple2<java.util.List<LabeledPoint>, DatasetInfo> arr_datasetInfo =
+      DecisionTreeSuite.generateCategoricalDataPointsAsList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr_datasetInfo._1());
+    DatasetInfo datasetInfo = arr_datasetInfo._2();
+
+    DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams();
+    DecisionTreeClassifierModel model =
+      DecisionTreeClassifier.train(rdd.rdd(), datasetInfo, dtParams);
+
+    int numAccurate = validatePrediction(arr_datasetInfo._1(), model);
+    Assert.assertTrue(numAccurate == rdd.count());
+  }
+*/
+}
From 320853f464ca8658d7e28a9f39f288da33c88b23 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Fri, 1 Aug 2014 17:40:53 -0700
Subject: [PATCH 3/6] Added JavaDecisionTreeSuite, partly written

---
 .../mllib/tree/JavaDecisionTreeSuite.java     | 35 ++++++++++++++-----
 .../spark/mllib/tree/DecisionTreeSuite.scala  |  6 ++++
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 13e86ec9f4161..c5fc2f89d9aef 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -19,10 +19,14 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.rdd.DatasetInfo;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.configuration.DTClassifierParams;
-import org.apache.spark.mllib.tree.model.DecisionTreeClassifierModel;
+import org.apache.spark.mllib.tree.DecisionTree;
+import org.apache.spark.mllib.tree.configuration.Algo;
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy;
+import org.apache.spark.mllib.tree.configuration.Strategy;
+import org.apache.spark.mllib.tree.impurity.Gini;
+import org.apache.spark.mllib.tree.impurity.Impurity;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -45,7 +49,7 @@ public void tearDown() {
     sc = null;
   }
 
-  int validatePrediction(List<LabeledPoint> validationData, DecisionTreeClassifierModel model) {
+  int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
     int numAccurate = 0;
     for (LabeledPoint point: validationData) {
       Double prediction = model.predict(point.features());
@@ -58,12 +62,25 @@ int validatePrediction(List<LabeledPoint> validationData, DecisionTreeClassifier
 
   @Test
   public void runDTUsingConstructor() {
-    scala.Tuple2<java.util.List<LabeledPoint>, DatasetInfo> arr_datasetInfo =
-      DecisionTreeSuite.generateCategoricalDataPointsAsList();
-    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr_datasetInfo._1());
-    DatasetInfo datasetInfo = arr_datasetInfo._2();
+    List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
 
-    DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams();
+    int maxDepth = 4;
+    int numClasses = 2;
+
+    Strategy strategy = new Strategy(Algo.Classification(), Gini(), maxDepth, numClasses, maxBins, QuantileStrategy.Sort(), categoricalFeaturesInfo);
+
+    val algo: Algo,
+    val impurity: Impurity,
+    val maxDepth: Int,
+    val numClassesForClassification: Int = 2,
+    val maxBins: Int = 100,
+    val quantileCalculationStrategy: QuantileStrategy = Sort,
+    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+    val maxMemoryInMB: Int = 128) extends Serializable {
+
+
+    DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams();
     dtParams.setMaxBins(200);
     dtParams.setImpurity("entropy");
     DecisionTreeClassifier dtLearner = new DecisionTreeClassifier(dtParams);
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 546a132559326..151fb99d139ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.tree
 
+import scala.collection.JavaConversions._
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
@@ -814,6 +816,10 @@ object DecisionTreeSuite {
     arr
   }
 
+  def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = {
+    seqAsJavaList(generateCategoricalDataPoints().toSeq)
+  }
+
  def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
    val arr = new Array[LabeledPoint](3000)
    for (i <- 0 until 3000) {
From f7b5ca1ed464de5d7d20f4c006621afa8d8b9e56 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sat, 2 Aug 2014 12:56:47 -0700
Subject: [PATCH 4/6] Improvements to make it easier to run DecisionTree from
 Java.

* DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java.
* impurity classes: Added instance() methods to help with the Java interface.
* Strategy: Added Java-friendly constructor.
  ** Note: I removed quantileCalculationStrategy from the Java-friendly
     constructor since (a) it is a special class and (b) there is only one
     option currently. I suspect we will redo the API before the other
     options are included.
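Background for the retag() bullet: an RDD built from Java (via
JavaRDD.rdd()) carries a placeholder element ClassTag, since Java erases the
type argument and cannot supply the real one, so Scala code that needs to
allocate typed arrays from the RDD (e.g., in take or collect) can fail at
runtime. A minimal sketch of the fix, under those assumptions (the helper
name is illustrative; retag itself is the real RDD method used below):

    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // javaRdd.rdd() has only a generic ClassTag for its elements; retag
    // restores ClassTag[LabeledPoint] so typed array allocation works.
    def fromJava(javaRdd: JavaRDD[LabeledPoint]): RDD[LabeledPoint] =
      javaRdd.rdd().retag(classOf[LabeledPoint])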
---
 .../spark/mllib/tree/DecisionTree.scala       |  8 +--
 .../mllib/tree/configuration/Strategy.scala   | 29 ++++++++
 .../spark/mllib/tree/impurity/Entropy.scala   |  7 ++
 .../spark/mllib/tree/impurity/Gini.scala      |  7 ++
 .../spark/mllib/tree/impurity/Variance.scala  |  7 ++
 .../mllib/tree/JavaDecisionTreeSuite.java     | 68 ++++++++++---------
 6 files changed, 89 insertions(+), 37 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 382e76a9b7cba..1d03e6e3b36cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -48,12 +48,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
     // Cache input RDD for speedup during multiple passes.
-    input.cache()
+    val retaggedInput = input.retag(classOf[LabeledPoint]).cache()
     logDebug("algo = " + strategy.algo)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
 
@@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
     // num features
-    val numFeatures = input.take(1)(0).features.size
+    val numFeatures = retaggedInput.take(1)(0).features.size
 
     // Calculate level for single group construction
 
@@ -107,7 +107,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     logDebug("#####################################")
 
     // Find best split for all nodes at a level.
-    val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
+    val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities,
       strategy, level, filters, splits, bins, maxLevelForSingleGroup)
     for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 5c65b537b6867..1a5492e0dfb53 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.tree.configuration
 
+import scala.collection.JavaConversions._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.configuration.Algo._
@@ -60,4 +62,31 @@ class Strategy (
   val isMulticlassWithCategoricalFeatures
     = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
 
+  /**
+   * Java-friendly constructor.
+   *
+   * @param algo classification or regression
+   * @param impurity criterion used for information gain calculation
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   * @param numClassesForClassification number of classes for classification. The default value
+   *                                    of 2 leads to binary classification.
+   * @param maxBins maximum number of bins used for splitting features
+   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+   *                                the number of discrete values they take. For example, an entry
+   *                                (n -> k) implies the feature n is categorical with k categories
+   *                                0, 1, 2, ... , k-1. It's important to note that features are
+   *                                zero-indexed.
+   */
+  def this(
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int,
+      numClassesForClassification: Int,
+      maxBins: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
+    this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort,
+      categoricalFeaturesInfo.map { case (a, b) => (a.toInt, b.toInt) }.toMap)
+  }
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 9297c20596527..96d2471e1f88c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -66,4 +66,11 @@ object Entropy extends Impurity {
   @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Entropy.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 2874bcf496484..d586f449048bb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -62,4 +62,11 @@ object Gini extends Impurity {
   @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Gini.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 698a1a2a8e899..f7d99a40eb380 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -53,4 +53,11 @@ object Variance extends Impurity {
     val squaredLoss = sumSquares - (sum * sum) / count
     squaredLoss / count
   }
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index c5fc2f89d9aef..c39dee1fac172 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -17,6 +17,12 @@
 
 package org.apache.spark.mllib.tree;
 
+import scala.Int;
+import scala.collection.immutable.Map;
+import scala.collection.JavaConverters.*;
+import scala.Predef;
+import scala.Tuple2;
+
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
@@ -25,6 +31,7 @@
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy;
 import org.apache.spark.mllib.tree.configuration.Strategy;
 import org.apache.spark.mllib.tree.impurity.Gini;
+import org.apache.spark.mllib.tree.impurity.Gini$;
 import org.apache.spark.mllib.tree.impurity.Impurity;
 import org.apache.spark.mllib.tree.model.DecisionTreeModel;
 import org.junit.After;
@@ -33,6 +40,7 @@
 import org.junit.Test;
 
 import java.io.Serializable;
+import java.util.HashMap;
 import java.util.List;
 
 public class JavaDecisionTreeSuite implements Serializable {
@@ -50,59 +58,53 @@ public void tearDown() {
   }
 
   int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
-    int numAccurate = 0;
+    int numCorrect = 0;
     for (LabeledPoint point: validationData) {
       Double prediction = model.predict(point.features());
       if (prediction == point.label()) {
-        numAccurate++;
+        numCorrect++;
       }
     }
-    return numAccurate;
+    return numCorrect;
   }
 
   @Test
   public void runDTUsingConstructor() {
     List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
     JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+    categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
 
     int maxDepth = 4;
     int numClasses = 2;
+    int maxBins = 100;
+    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+      maxBins, categoricalFeaturesInfo);
 
-    Strategy strategy = new Strategy(Algo.Classification(), Gini(), maxDepth, numClasses, maxBins, QuantileStrategy.Sort(), categoricalFeaturesInfo);
-
-    val algo: Algo,
-    val impurity: Impurity,
-    val maxDepth: Int,
-    val numClassesForClassification: Int = 2,
-    val maxBins: Int = 100,
-    val quantileCalculationStrategy: QuantileStrategy = Sort,
-    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable {
-
+    DecisionTree learner = new DecisionTree(strategy);
+    DecisionTreeModel model = learner.train(rdd.rdd());
 
-    DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams();
-    dtParams.setMaxBins(200);
-    dtParams.setImpurity("entropy");
-    DecisionTreeClassifier dtLearner = new DecisionTreeClassifier(dtParams);
-    DecisionTreeClassifierModel model = dtLearner.run(rdd.rdd(), datasetInfo);
-
-    int numAccurate = validatePrediction(arr_datasetInfo._1(), model);
-    Assert.assertTrue(numAccurate == rdd.count());
+    int numCorrect = validatePrediction(arr, model);
+    Assert.assertTrue(numCorrect == rdd.count());
   }
-/*
+
   @Test
   public void runDTUsingStaticMethods() {
-    scala.Tuple2<java.util.List<LabeledPoint>, DatasetInfo> arr_datasetInfo =
-      DecisionTreeSuite.generateCategoricalDataPointsAsList();
-    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr_datasetInfo._1());
-    DatasetInfo datasetInfo = arr_datasetInfo._2();
+    List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+    categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
+
+    int maxDepth = 4;
+    int numClasses = 2;
+    int maxBins = 100;
+    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+      maxBins, categoricalFeaturesInfo);
 
-    DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams();
-    DecisionTreeClassifierModel model =
-      DecisionTreeClassifier.train(rdd.rdd(), datasetInfo, dtParams);
+    DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
 
-    int numAccurate = validatePrediction(arr_datasetInfo._1(), model);
-    Assert.assertTrue(numAccurate == rdd.count());
+    int numCorrect = validatePrediction(arr, model);
+    Assert.assertTrue(numCorrect == rdd.count());
   }
-*/
+
 }
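A note on the Gini.instance() and DecisionTree$.MODULE$ usages in the patch
above: both address the same Scala/Java mismatch. A Scala `object` compiles
to a class whose single value lives in a static MODULE$ field, which is
awkward to reference from Java, while methods defined on the object surface
to Java as ordinary static methods. A minimal Scala sketch of the pattern
(GiniLike is an illustrative stand-in, not the real impurity class):

    trait ImpurityLike {
      def calculate(counts: Array[Double], totalCount: Double): Double
    }

    object GiniLike extends ImpurityLike {
      // Gini impurity: 1 - sum over classes of (class frequency)^2.
      def calculate(counts: Array[Double], totalCount: Double): Double =
        1.0 - counts.map(c => (c / totalCount) * (c / totalCount)).sum

      // From Java, `GiniLike` alone is not an expression; callers would
      // need GiniLike$.MODULE$. This accessor lets Java write
      // GiniLike.instance() instead.
      def instance = this
    }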
From 519b1b796177d4f2ad3337bffb6efb65fcd9a39c Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sat, 2 Aug 2014 19:36:53 -0700
Subject: [PATCH 5/6] * Organized imports in JavaDecisionTreeSuite.java
 * Using JavaConverters instead of JavaConversions in DecisionTreeSuite.scala

---
 .../mllib/tree/JavaDecisionTreeSuite.java     | 24 +++++++------------
 .../spark/mllib/tree/DecisionTreeSuite.scala  |  4 ++--
 2 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index c39dee1fac172..2c281a1ee7157 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -17,31 +17,23 @@
 
 package org.apache.spark.mllib.tree;
 
-import scala.Int;
-import scala.collection.immutable.Map;
-import scala.collection.JavaConverters.*;
-import scala.Predef;
-import scala.Tuple2;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree;
 import org.apache.spark.mllib.tree.configuration.Algo;
-import org.apache.spark.mllib.tree.configuration.QuantileStrategy;
 import org.apache.spark.mllib.tree.configuration.Strategy;
 import org.apache.spark.mllib.tree.impurity.Gini;
-import org.apache.spark.mllib.tree.impurity.Gini$;
-import org.apache.spark.mllib.tree.impurity.Impurity;
 import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.List;
 
 public class JavaDecisionTreeSuite implements Serializable {
   private transient JavaSparkContext sc;
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 151fb99d139ce..76fd266c50eed 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.tree
 
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
 
 import org.scalatest.FunSuite
 
@@ -817,7 +817,7 @@ object DecisionTreeSuite {
   }
 
   def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = {
-    seqAsJavaList(generateCategoricalDataPoints().toSeq)
+    generateCategoricalDataPoints().toList.asJava
   }
 
   def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
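The switch from JavaConversions to JavaConverters trades implicit, silent
conversions for explicit .asJava / .asScala calls at each language boundary.
A minimal Scala sketch of the difference, assuming only the standard
library:

    // JavaConversions (implicit): a Scala Seq silently becomes a
    // java.util.List wherever one is expected, hiding the conversion.
    //   import scala.collection.JavaConversions._
    //   val jList: java.util.List[String] = Seq("a", "b")  // implicit

    // JavaConverters (explicit): the conversion is visible at the call site.
    import scala.collection.JavaConverters._

    val jList: java.util.List[String] = Seq("a", "b").asJava
    val sBuffer: scala.collection.mutable.Buffer[String] = jList.asScala

Both approaches produce wrappers rather than copies, so the explicit form
costs nothing extra at runtime while making cross-language boundaries
obvious in the code.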
From 0805dc6ccdb40d73e82a1d6989db3ff62cc805a7 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sun, 3 Aug 2014 00:53:34 -0700
Subject: [PATCH 6/6] Changed Strategy to use JavaConverters instead of
 JavaConversions

---
 .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 1a5492e0dfb53..6121000d37f90 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.tree.configuration
 
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.tree.impurity.Impurity
@@ -86,7 +86,7 @@ class Strategy (
       maxBins: Int,
       categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
     this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort,
-      categoricalFeaturesInfo.map { case (a, b) => (a.toInt, b.toInt) }.toMap)
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
   }
 
 }
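A closing note on the final conversion: the asInstanceOf cast from
java.util.Map[java.lang.Integer, java.lang.Integer] to
java.util.Map[Int, Int] works because generics are erased at runtime and
scala.Int in a generic position is represented by the same boxed
java.lang.Integer. A minimal sketch of the behavior, assuming only the
standard library (the helper name is illustrative):

    import scala.collection.JavaConverters._

    def toScalaMap(
        jMap: java.util.Map[java.lang.Integer, java.lang.Integer]): Map[Int, Int] = {
      // Erasure makes this cast a no-op at runtime; asScala wraps the map,
      // and toMap copies it into an immutable Scala Map[Int, Int]. The cast
      // is sound because Int erases to java.lang.Integer in generic
      // positions, so the boxed values are already the right runtime type.
      jMap.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap
    }

    // Example usage:
    //   val jMap = new java.util.HashMap[java.lang.Integer, java.lang.Integer]()
    //   jMap.put(0, 2); jMap.put(1, 2)
    //   toScalaMap(jMap)  // Map(0 -> 2, 1 -> 2)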