From 2daba5fe042330ea1060f4f5c48b84d225739b82 Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 20 Oct 2016 09:54:53 -0700 Subject: [PATCH 1/2] add instr to gbt --- .../apache/spark/ml/classification/GBTClassifier.scala | 10 +++++++++- .../org/apache/spark/ml/regression/GBTRegressor.scala | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ba70293273f94..8bffe0cda0327 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -137,9 +137,17 @@ class GBTClassifier @Since("1.4.0") ( } val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + instr.logNumClasses(2) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index bb01f9d5a364c..642132953a5ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -123,9 +123,17 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + instr.logNumClasses(0) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.0") From b2ad1c8a29187ae765440ae12713017fb5b92bd1 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 25 Oct 2016 10:27:41 -0700 Subject: [PATCH 2/2] remove log num classes for regression --- .../main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 642132953a5ce..fa69d60836e68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -127,7 +127,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) instr.logNumFeatures(numFeatures) - instr.logNumClasses(0) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))