From 20fb05854db4c31d4980a69a747fe675129af209 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 11 May 2017 18:07:37 +0800 Subject: [PATCH 1/7] create pr --- .../spark/mllib/stat/MultivariateOnlineSummarizer.scala | 6 ++++-- .../mllib/stat/MultivariateOnlineSummarizerSuite.scala | 9 +++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7dc0c459ec032..04b3c396d368c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -249,7 +249,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if (nnz(i) == 0L) currMax(i) = Double.NaN + else if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) @@ -265,7 +266,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if (nnz(i) == 0L) currMax(i) = Double.NaN + else if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 797e84fcc7377..c8fbf6213d4e0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -270,4 +270,13 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) } + + test("test min/max with identical NaN feature") { + val summarizer = new MultivariateOnlineSummarizer() + .add(Vectors.dense(Double.NaN, -10.0)) + .add(Vectors.dense(Double.NaN, 0.0)) + + assert(summarizer.min(0).isNaN) + assert(summarizer.max(0).isNaN) + } } From c3c36ad46ae6fd90afaff0a4e3a31c58ad44280a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 11 May 2017 18:11:34 +0800 Subject: [PATCH 2/7] fix a nit --- .../apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 04b3c396d368c..2ccd98ebcde9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -266,7 +266,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if (nnz(i) == 0L) currMax(i) = Double.NaN + if (nnz(i) == 0L) currMin(i) = Double.NaN else if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } From ab4c4b196301ce302237648d6981ba3d24b2b112 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 11 May 2017 18:17:56 +0800 Subject: [PATCH 3/7] update condition --- .../spark/mllib/stat/MultivariateOnlineSummarizer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 2ccd98ebcde9a..85f6727485e66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -249,7 +249,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if (nnz(i) == 0L) currMax(i) = Double.NaN + if (currMax(i) < currMin(i)) currMax(i) = Double.NaN else if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } @@ -266,7 +266,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if (nnz(i) == 0L) currMin(i) = Double.NaN + if (currMax(i) < currMin(i)) currMin(i) = Double.NaN else if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } From 002964b9636f9c91e48588a1d3bb0283122f0263 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 11 May 2017 19:11:48 +0800 Subject: [PATCH 4/7] update cond --- .../mllib/stat/MultivariateOnlineSummarizer.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 85f6727485e66..c55220a2a59a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -249,8 +249,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if (currMax(i) < currMin(i)) currMax(i) = Double.NaN - else if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if (nnz(i) < totalCnt) { + if (currMax(i) < 0.0) currMax(i) = 0.0 + } else if (currMax(i) < currMin(i)) { + currMax(i) = Double.NaN + } i += 1 } Vectors.dense(currMax) @@ -266,8 +269,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if (currMax(i) < currMin(i)) currMin(i) = Double.NaN - else if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if (nnz(i) < totalCnt) { + if (currMin(i) > 0.0) currMin(i) = 0.0 + } else if (currMax(i) < currMin(i)) { + currMin(i) = Double.NaN + } i += 1 } Vectors.dense(currMin) From 491daa1fb8adf73bbfd2365beab7344860eea38e Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 11 May 2017 19:19:13 +0800 Subject: [PATCH 5/7] update cond --- .../apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index c55220a2a59a7..6113afe31771c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -252,6 +252,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if (nnz(i) < totalCnt) { if (currMax(i) < 0.0) currMax(i) = 0.0 } else if (currMax(i) < currMin(i)) { + currMin(i) = Double.NaN currMax(i) = Double.NaN } i += 1 @@ -273,6 +274,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if (currMin(i) > 0.0) currMin(i) = 0.0 } else if (currMax(i) < currMin(i)) { currMin(i) = Double.NaN + currMax(i) = Double.NaN } i += 1 } From f7eaddf0f1a381bb24c457e8b6495cd615e3c795 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 21 Aug 2017 11:25:25 +0800 Subject: [PATCH 6/7] update --- .../stat/MultivariateOnlineSummarizer.scala | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 6113afe31771c..2dd9924cdf664 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -97,12 +97,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localCurrMin = currMin instance.foreachActive { (index, value) => if (value != 0.0) { - if (localCurrMax(index) < value) { - localCurrMax(index) = value - } - if (localCurrMin(index) > value) { - localCurrMin(index) = value - } + localCurrMax(index) = math.max(localCurrMax(index), value) + localCurrMin(index) = math.min(localCurrMin(index), value) val prevMean = localCurrMean(index) val diff = value - prevMean @@ -249,12 +245,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if (nnz(i) < totalCnt) { - if (currMax(i) < 0.0) currMax(i) = 0.0 - } else if (currMax(i) < currMin(i)) { - currMin(i) = Double.NaN - currMax(i) = Double.NaN - } + if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) @@ -270,12 +261,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 while (i < n) { - if (nnz(i) < totalCnt) { - if (currMin(i) > 0.0) currMin(i) = 0.0 - } else if (currMax(i) < currMin(i)) { - currMin(i) = Double.NaN - currMax(i) = Double.NaN - } + if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) From 44d117ad4ccad698eb331e3d9ac535bd9a438af0 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 21 Aug 2017 11:30:12 +0800 Subject: [PATCH 7/7] update test --- .../stat/MultivariateOnlineSummarizerSuite.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index c8fbf6213d4e0..08dc898a2022d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -271,12 +271,17 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) } - test("test min/max with identical NaN feature") { + test("test min/max with NaN feature") { val summarizer = new MultivariateOnlineSummarizer() .add(Vectors.dense(Double.NaN, -10.0)) .add(Vectors.dense(Double.NaN, 0.0)) - - assert(summarizer.min(0).isNaN) - assert(summarizer.max(0).isNaN) + .add(Vectors.dense(Double.NaN, Double.NaN)) + + val minVec = summarizer.min + val maxVec = summarizer.max + assert(minVec(0).isNaN) + assert(minVec(1).isNaN) + assert(maxVec(0).isNaN) + assert(maxVec(1).isNaN) } }