From fc578dc5cd3c6df66ce77cf7352639b1033f7b56 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 31 Aug 2015 17:51:42 +0800 Subject: [PATCH 1/3] make checkpointInterval shared params --- .../DecisionTreeClassifier.scala | 1 + .../ml/param/shared/SharedParamsCodeGen.scala | 3 ++- .../spark/ml/param/shared/sharedParams.scala | 8 ++++--- .../org/apache/spark/ml/tree/treeParams.scala | 24 +++---------------- 4 files changed, 11 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 6f70b96b17ec6..0a75d5d22280f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasCheckpointInterval import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8c16c6149b40d..84bd5dd795b81 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -56,7 +56,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), - ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", + ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " + + "the cache will get checkpointed every 10 iterations.", Some("10"), isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index c26768953e3db..f81969df826e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -218,15 +218,17 @@ private[ml] trait HasOutputCol extends Params { } /** - * Trait for shared param checkpointInterval. + * Trait for shared param checkpointInterval (default: 10). */ private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval (>= 1). + * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1)) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 
10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1)) + + setDefault(checkpointInterval, 10) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index dbd8d31571d2e..b4b543231ec2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams { +private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval { /** * Maximum depth of the tree (>= 0). @@ -96,23 +96,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams { " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") - /** - * Specifies how often to checkpoint the cached node IDs. - * E.g. 10 means that the cache will get checkpointed every 10 iterations. - * This is only used if cacheNodeIds is true and if the checkpoint directory is set in - * [[org.apache.spark.SparkContext]]. - * Must be >= 1. - * (default = 10) - * @group expertParam - */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + - " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + - " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + - " checkpoint directory is set in the SparkContext. Must be >= 1.", - ParamValidators.gtEq(1)) - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + maxMemoryInMB -> 256, cacheNodeIds -> false) /** @group setParam */ def setMaxDepth(value: Int): this.type = set(maxDepth, value) @@ -153,9 +138,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** @group expertSetParam */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** @group expertGetParam */ - final def getCheckpointInterval: Int = $(checkpointInterval) - /** (private[ml]) Create a Strategy instance to use with the old API. 
*/ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], From 85258ddea160d3ae7081f6f9c40324c6f9d5f6ff Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 1 Sep 2015 11:24:08 +0800 Subject: [PATCH 2/3] remove default value for checkpointInterval in sharedParams --- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 84bd5dd795b81..e9e99ed1db40e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -57,7 +57,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " + - "the cache will get checkpointed every 10 iterations.", Some("10"), + "the cache will get checkpointed every 10 iterations.", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index f81969df826e8..30092170863ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -218,7 +218,7 @@ private[ml] trait HasOutputCol extends Params { } /** - * Trait for shared param checkpointInterval (default: 10). + * Trait for shared param checkpointInterval. */ private[ml] trait HasCheckpointInterval extends Params { @@ -228,8 +228,6 @@ private[ml] trait HasCheckpointInterval extends Params { */ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 
10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1)) - setDefault(checkpointInterval, 10) - /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) } From 85ffd3bb4498b6b2469430e8c295de7c777931a3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 1 Sep 2015 11:30:34 +0800 Subject: [PATCH 3/3] more annotation for setCheckpointInterval --- .../scala/org/apache/spark/ml/tree/treeParams.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index b4b543231ec2e..d29f5253c9c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -97,7 +97,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI " trees.") setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false) + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) /** @group setParam */ def setMaxDepth(value: Int): this.type = set(maxDepth, value) @@ -135,7 +135,15 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** @group expertSetParam */ + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertSetParam + */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** (private[ml]) Create a Strategy instance to use with the old API. */
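
Note (illustration, not part of the patches above): a minimal sketch of how the consolidated checkpointInterval param behaves after this series. The param now comes from the shared HasCheckpointInterval trait, DecisionTreeParams still supplies the default of 10 via setDefault, and checkpointing only takes effect when cacheNodeIds is true and a checkpoint directory is set on the SparkContext. The DataFrame `training` with the default "label" and "features" columns is assumed here purely for illustration.

    import org.apache.spark.SparkContext
    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.sql.DataFrame

    // Illustrative only: `training` is an assumed DataFrame with "label" and "features" columns.
    def trainWithCheckpointing(sc: SparkContext, training: DataFrame): Unit = {
      // checkpointInterval is only honored when cacheNodeIds is true and a
      // checkpoint directory has been set on the SparkContext.
      sc.setCheckpointDir("/tmp/spark-checkpoints")

      val dt = new DecisionTreeClassifier()
        .setMaxDepth(8)
        .setCacheNodeIds(true)
        .setCheckpointInterval(10) // shared param from HasCheckpointInterval; default remains 10

      val model = dt.fit(training)
      println(s"Trained a tree with ${model.numNodes} nodes")
    }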