From 1e98a4c0a44138ff2f020ed4f9350c8dd968a911 Mon Sep 17 00:00:00 2001
From: AiHe
Date: Thu, 7 May 2015 14:19:58 -0700
Subject: [PATCH 1/4] [MLLIB][tree] Add reservoir sample in RandomForest

reservoir feature sample by using existing api
---
 .../scala/org/apache/spark/mllib/tree/RandomForest.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 055e60c7d9c95..93a6f571c726b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
+import org.apache.spark.util.random.SamplingUtils
 
 /**
  * :: Experimental ::
@@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging {
       val (treeIndex, node) = nodeQueue.head
       // Choose subset of features for node (if subsampling).
       val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
-        // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir)
-        Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
-          .take(metadata.numFeaturesPerNode).toArray)
+        Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+          metadata.numFeatures).iterator, metadata.numFeaturesPerNode)._1)
       } else {
         None
       }

From 37459e14cc80ce491ed3947270d13bbed5cba716 Mon Sep 17 00:00:00 2001
From: AiHe
Date: Fri, 8 May 2015 16:36:44 -0700
Subject: [PATCH 2/4] set fixed seed

---
 .../main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 93a6f571c726b..dee132d6ee77a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -475,7 +475,7 @@ object RandomForest extends Serializable with Logging {
       // Choose subset of features for node (if subsampling).
       val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
         Some(SamplingUtils.reservoirSampleAndCount(Range(0,
-          metadata.numFeatures).iterator, metadata.numFeaturesPerNode)._1)
+          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, 41L)._1)
       } else {
         None
       }

From 28ffb9ab7f4c17e7ebc5fae9207db0627f19f4a1 Mon Sep 17 00:00:00 2001
From: AiHe
Date: Mon, 11 May 2015 23:16:00 -0700
Subject: [PATCH 3/4] set seed as rng.nextLong

---
 .../main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index dee132d6ee77a..b347c450c1aa8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -475,7 +475,7 @@ object RandomForest extends Serializable with Logging {
       // Choose subset of features for node (if subsampling).
       val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
         Some(SamplingUtils.reservoirSampleAndCount(Range(0,
-          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, 41L)._1)
+          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
       } else {
         None
       }

From e7a41ac50ae1f2c06c1c37313f1101c2a8b7f699 Mon Sep 17 00:00:00 2001
From: AiHe
Date: Fri, 15 May 2015 00:59:38 -0700
Subject: [PATCH 4/4] remove non-robust testing case

---
 .../scala/org/apache/spark/mllib/tree/RandomForestSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index ee3bc98486862..4ed66953cb628 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,7 +196,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
       numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
     val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
       featureSubsetStrategy = "sqrt", seed = 12345)
-    EnsembleTestHelper.validateClassifier(model, arr, 1.0)
   }
 
   test("subsampling rate in RandomForest"){
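
Note: SamplingUtils.reservoirSampleAndCount is an internal Spark utility (private[spark]), so it cannot be called from outside Spark. The sketch below is a minimal standalone illustration of the single-pass reservoir sampling this patch series delegates to for drawing a per-node feature subset; it is not the Spark implementation. The object name ReservoirSampleSketch and the example sizes (numFeatures = 20, numFeaturesPerNode = 5) are hypothetical and not part of this patch.

import scala.reflect.ClassTag
import scala.util.Random

object ReservoirSampleSketch {
  // Draw k items uniformly at random from an iterator in one pass and
  // return them together with the total number of items seen.
  def reservoirSampleAndCount[T: ClassTag](
      input: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
    val reservoir = new Array[T](k)
    val rand = new Random(seed)
    var i = 0L
    while (input.hasNext) {
      val item = input.next()
      if (i < k) {
        // Fill the reservoir with the first k items.
        reservoir(i.toInt) = item
      } else {
        // Keep the (i + 1)-th item with probability k / (i + 1),
        // replacing a uniformly chosen reservoir slot.
        val j = (rand.nextDouble() * (i + 1)).toLong
        if (j < k) reservoir(j.toInt) = item
      }
      i += 1
    }
    if (i < k) (reservoir.take(i.toInt), i) else (reservoir, i)
  }

  def main(args: Array[String]): Unit = {
    val numFeatures = 20        // hypothetical stand-in for metadata.numFeatures
    val numFeaturesPerNode = 5  // hypothetical stand-in for metadata.numFeaturesPerNode
    val rng = new Random(12345)
    // Per-node feature subset, seeded from a node-level RNG as in patch 3/4
    // so that different nodes draw different subsets.
    val (featureSubset, seen) =
      reservoirSampleAndCount(Range(0, numFeatures).iterator, numFeaturesPerNode, rng.nextLong())
    println(s"sampled features ${featureSubset.sorted.mkString(", ")} out of $seen")
  }
}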