From 0092c77b647d49865e8474a302e353f13195524e Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Wed, 30 Aug 2017 10:01:55 +0800
Subject: [PATCH 1/3] create pr

---
 .../main/scala/org/apache/spark/ml/stat/Summarizer.scala    | 8 ++++----
 .../spark/mllib/stat/MultivariateOnlineSummarizer.scala     | 9 ++-------
 .../scala/org/apache/spark/ml/stat/SummarizerSuite.scala    | 8 ++++++++
 .../mllib/stat/MultivariateOnlineSummarizerSuite.scala      | 8 ++++++++
 4 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index cae41edb7aca8..59bf23b983344 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -299,11 +299,11 @@ private[ml] object SummaryBuilderImpl extends Logging {
     val localCurrMin = currMin
     instance.foreachActive { (index, value) =>
       if (value != 0.0) {
-        if (localCurrMax != null && localCurrMax(index) < value) {
-          localCurrMax(index) = value
+        if (localCurrMax != null) {
+          localCurrMax(index) = math.max(localCurrMax(index), value)
         }
-        if (localCurrMin != null && localCurrMin(index) > value) {
-          localCurrMin(index) = value
+        if (localCurrMin != null) {
+          localCurrMin(index) = math.min(localCurrMin(index), value)
         }
 
         if (localWeightSum != null) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 8121880cfb233..ee733802a97cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -97,13 +97,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
     val localCurrMin = currMin
     instance.foreachActive { (index, value) =>
       if (value != 0.0) {
-        if (localCurrMax(index) < value) {
-          localCurrMax(index) = value
-        }
-        if (localCurrMin(index) > value) {
-          localCurrMin(index) = value
-        }
-
+        localCurrMax(index) = math.max(localCurrMax(index), value)
+        localCurrMin(index) = math.min(localCurrMin(index), value)
         val prevMean = localCurrMean(index)
         val diff = value - prevMean
         localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
index 1ea851ef2d676..6bf0c9b52d84c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -501,12 +501,20 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
     summarizer3.add(Vectors.dense(10.0, -10.0), 1e10)
 
+    val summarizer4 = new SummarizerBuffer()
+    summarizer4.add(Vectors.dense(10.0, Double.NaN), 1)
+    summarizer4.add(Vectors.dense(Double.NaN, Double.NaN), 1)
+
     assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
     assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
     assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+    assert(summarizer4.max(0).isNaN)
+    assert(summarizer4.max(1).isNaN)
+    assert(summarizer4.min(0).isNaN)
+    assert(summarizer4.min(1).isNaN)
   }
 
   ignore("performance test") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index c6466bc918dd0..78e5d59f8c299 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -263,12 +263,20 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
     }
     summarizer3.add(Vectors.dense(10.0, -10.0), 1e10)
 
+    val summarizer4 = new MultivariateOnlineSummarizer()
+    summarizer4.add(Vectors.dense(10.0, Double.NaN), 1)
+    summarizer4.add(Vectors.dense(Double.NaN, Double.NaN), 1)
+
     assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
     assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
     assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+    assert(summarizer4.max(0).isNaN)
+    assert(summarizer4.max(1).isNaN)
+    assert(summarizer4.min(0).isNaN)
+    assert(summarizer4.min(1).isNaN)
   }
 
   test ("test zero variance (SPARK-21818)") {
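Context for [PATCH 1/3]: the point of replacing the comparison-guarded updates with math.max/math.min is NaN propagation. Every comparison against NaN evaluates to false, so a guard like `if (cur < value)` never stores a NaN, while math.max/math.min carry it through. A minimal standalone Scala sketch (illustrative only, not part of the patch; the object name is made up):

// Illustrative sketch only -- not from the patch. Shows the NaN behavior
// that the new summarizer4 tests assert.
object NaNPropagationSketch extends App {
  var curMax = 1.0
  val v = Double.NaN

  if (curMax < v) curMax = v      // guard is false: the NaN is silently lost
  println(curMax)                 // prints 1.0

  curMax = math.max(curMax, v)    // math.max returns NaN if either side is NaN
  println(curMax)                 // prints NaN
}
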
From 8ba9dc4e34074b84dcfcc88e003c7393390db08b Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Tue, 5 Dec 2017 16:38:39 +0800
Subject: [PATCH 2/3] update MinMaxScaler

---
 .../spark/ml/feature/MinMaxScaler.scala | 57 ++++++++++++++++---
 1 file changed, 49 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index f648deced54cd..a353e8667dec9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -25,11 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
-import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -117,11 +113,56 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
     transformSchema(dataset.schema, logging = true)
-    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
-      case Row(v: Vector) => OldVectors.fromML(v)
+
+    val vectors = dataset.select($(inputCol)).rdd.map {
+      case Row(v: Vector) => v
+    }
+    val numFeatures = vectors.first().size
+
+    val zeroValue = (0L, Array.ofDim[Long](numFeatures),
+      Array.fill(numFeatures)(Double.MaxValue), Array.fill(numFeatures)(Double.MinValue))
+
+    val (count, nnz, min, max) = vectors.treeAggregate(zeroValue)(
+      seqOp = { case ((count, nnz, min, max), vec) =>
+        require(vec.size == numFeatures)
+        vec.foreachActive { (i, v) =>
+          if (v != 0) {
+            if (v < min(i)) {
+              min(i) = v
+            }
+            if (v > max(i)) {
+              max(i) = v
+            }
+            nnz(i) += 1
+          }
+        }
+        (count + 1, nnz, min, max)
+
+      }, combOp = { case ((count1, nnz1, min1, max1), (count2, nnz2, min2, max2)) =>
+        var i = 0
+        while (i < numFeatures) {
+          nnz1(i) += nnz2(i)
+          min1(i) = math.min(min1(i), min2(i))
+          max1(i) = math.max(max1(i), max1(i))
+          i += 1
+        }
+        (count1 + count2, nnz1, min1, max1)
+      })
+
+    var i = 0
+    while (i < numFeatures) {
+      if (nnz(i) < count) {
+        if (min(i) > 0) {
+          min(i) = 0
+        }
+        if (max(i) < 0) {
+          max(i) = 0
+        }
+      }
+      i += 1
     }
-    val summary = Statistics.colStats(input)
-    copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this))
+
+    copyValues(new MinMaxScalerModel(uid, Vectors.dense(min), Vectors.dense(max)).setParent(this))
   }
 
   @Since("1.5.0")
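Context for [PATCH 2/3]: fit() now computes the per-feature statistics itself via treeAggregate instead of delegating to mllib's Statistics.colStats. The sketch below (illustrative only, not part of the patch; it folds over a local Seq rather than an RDD, and all names are made up) walks the same state -- (count, nnz, min, max) -- including the final pass that widens a feature's range to cover the implicit zeros of sparse rows:

// Illustrative sketch only -- not from the patch. Same seqOp logic as fit(),
// applied with foldLeft to a local collection of dense rows.
object MinMaxAggregateSketch extends App {
  val data = Seq(Array(1.0, 0.0), Array(3.0, -2.0))   // two rows, two features
  val numFeatures = 2

  val zero = (0L, Array.ofDim[Long](numFeatures),
    Array.fill(numFeatures)(Double.MaxValue), Array.fill(numFeatures)(Double.MinValue))

  val (count, nnz, min, max) = data.foldLeft(zero) { case ((c, n, mn, mx), row) =>
    row.zipWithIndex.foreach { case (v, i) =>
      if (v != 0) {               // sparse vectors only surface non-zero entries
        mn(i) = math.min(mn(i), v)
        mx(i) = math.max(mx(i), v)
        n(i) += 1
      }
    }
    (c + 1, n, mn, mx)
  }

  // A feature with fewer non-zeros than rows has implicit zeros,
  // so its observed range must be widened to include 0:
  for (i <- 0 until numFeatures if nnz(i) < count) {
    if (min(i) > 0) min(i) = 0
    if (max(i) < 0) max(i) = 0
  }
  println(s"count=$count nnz=${nnz.mkString(",")} " +
    s"min=${min.mkString(",")} max=${max.mkString(",")}")
  // count=2 nnz=2,1 min=1.0,-2.0 max=3.0,0.0
}
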
From 80df3c69ac58ae5f4190b0b62dd2bdb1cef15260 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Tue, 5 Dec 2017 16:40:16 +0800
Subject: [PATCH 3/3] update MinMaxScaler 2

---
 .../main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index a353e8667dec9..a9b3e7df7111b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -143,7 +143,7 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String)
       while (i < numFeatures) {
         nnz1(i) += nnz2(i)
         min1(i) = math.min(min1(i), min2(i))
-        max1(i) = math.max(max1(i), max1(i))
+        max1(i) = math.max(max1(i), max2(i))
         i += 1
       }
       (count1 + count2, nnz1, min1, max1)
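
Context for [PATCH 3/3]: the one-character fix above matters because math.max(max1(i), max1(i)) compares a value with itself, so combOp silently dropped any maximum found in the other partial state. A standalone sketch of the difference (illustrative only, not part of the patch; names are made up):

// Illustrative sketch only -- not from the patch. With the typo, the merged
// maximum ignores the second partition's partial state entirely.
object CombOpTypoSketch extends App {
  val max1 = Array(3.0)   // partial max from one partition
  val max2 = Array(7.0)   // partial max from another partition

  val buggy = math.max(max1(0), max1(0))   // 3.0: max2 never consulted
  val fixed = math.max(max1(0), max2(0))   // 7.0: correct merged maximum
  println(s"buggy=$buggy fixed=$fixed")
}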