From 08fcb7c76eef39a3b9075d788365cb11f68c9c04 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Wed, 31 Aug 2016 18:14:57 -0700
Subject: [PATCH] update.

---
Review notes (free text after the three-dash line; not part of the commit):

* What the patch does: LogisticRegression and LinearRegression each compute a
  local `autoAggregationDepth` and pass it to `treeAggregate` and to the cost
  function in place of `$(aggregationDepth)`. When `$(aggregationDepth) > 0`
  the param value is used unchanged; otherwise the value comes from the new
  `HasAggregationDepth.getAggregationDepthByFormula`.
* `getAggregationDepthByFormula` returns 2 when `dimension > 200` and 1
  otherwise. Its `driverMemory` and `partitionNum` arguments are accepted but
  not used by the current body.
* LogisticRegression obtains the dimension from the first row of `dataset`
  (`select(...).first()`), evaluated only on the formula branch;
  LinearRegression passes its existing `numFeatures` value.
* NOTE(review): if the `aggregationDepth` param validator already rejects
  values <= 0, the formula branch can never run -- confirm against the param
  definition, which this patch does not touch.
* NOTE(review): sharedParams.scala is presumably generated code -- confirm
  whether the generator needs the same addition.
* NOTE(review): line breaks and context-line indentation in this copy were
  restored from a whitespace-collapsed version of the patch; verify the
  context lines against the target files before applying.

 .../spark/ml/classification/LogisticRegression.scala | 12 ++++++++++--
 .../apache/spark/ml/param/shared/sharedParams.scala  |  4 ++++
 .../spark/ml/regression/LinearRegression.scala       | 11 +++++++++--
 3 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 757d52052d87f..64dee37438923 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -295,6 +295,13 @@ class LogisticRegression @Since("1.2.0") (
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)
 
+    val autoAggregationDepth =
+      if ($(aggregationDepth) > 0) $(aggregationDepth)
+      else getAggregationDepthByFormula(
+        instances.context.getConf.getSizeAsBytes("spark.driver.memory", "1g"),
+        dataset.select(col($(featuresCol))).first().getAs[Vector](0).size,
+        instances.getNumPartitions
+      )
     val (summarizer, labelSummarizer) = {
       val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
         instance: Instance) =>
@@ -306,7 +313,7 @@ class LogisticRegression @Since("1.2.0") (
 
       instances.treeAggregate(
         new MultivariateOnlineSummarizer, new MultiClassSummarizer
-      )(seqOp, combOp, $(aggregationDepth))
+      )(seqOp, combOp, autoAggregationDepth)
     }
 
     val histogram = labelSummarizer.histogram
@@ -370,7 +377,8 @@ class LogisticRegression @Since("1.2.0") (
 
         val bcFeaturesStd = instances.context.broadcast(featuresStd)
         val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-          $(standardization), bcFeaturesStd, regParamL2, multinomial = false, $(aggregationDepth))
+          $(standardization), bcFeaturesStd, regParamL2, multinomial = false,
+          autoAggregationDepth)
 
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 9125d9e19bf09..59348c6428350 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -405,5 +405,9 @@ private[ml] trait HasAggregationDepth extends Params {
 
   /** @group expertGetParam */
   final def getAggregationDepth: Int = $(aggregationDepth)
+
+  def getAggregationDepthByFormula(driverMemory: Long, dimension: Int, partitionNum: Int): Int = {
+    if (dimension > 200) 2 else 1
+  }
 }
 // scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7fddfd9b10f84..0b4b4213d10ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -231,6 +231,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
+    val autoAggregationDepth =
+      if ($(aggregationDepth) > 0) $(aggregationDepth)
+      else getAggregationDepthByFormula(
+        instances.context.getConf.getSizeAsBytes("spark.driver.memory", "1g"),
+        numFeatures, instances.getNumPartitions
+      )
     val (featuresSummarizer, ySummarizer) = {
       val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
         instance: Instance) =>
@@ -243,7 +249,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
       instances.treeAggregate(
         new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer
-      )(seqOp, combOp, $(aggregationDepth))
+      )(seqOp, combOp, autoAggregationDepth)
     }
 
     val yMean = ySummarizer.mean(0)
@@ -309,7 +315,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
     val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
-      $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam, $(aggregationDepth))
+      $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam,
+      autoAggregationDepth)
 
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))