From ebe3e74df70eb424aecc3170fc55008cfb6a76ec Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 28 Oct 2014 22:42:50 -0700 Subject: [PATCH 1/3] First commit --- .../stat/MultivariateOnlineSummarizer.scala | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 3025d4837cab4..06001a0421e60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat -import breeze.linalg.{DenseVector => BDV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.{Vectors, Vector} @@ -72,9 +72,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - sample.toBreeze.activeIterator.foreach { - case (_, 0.0) => // Skip explicit zero elements. - case (i, value) => + @inline def update(i: Int, value: Double) = { + if (value != 0.0) { if (currMax(i) < value) { currMax(i) = value } @@ -89,6 +88,24 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currL1(i) += math.abs(value) nnz(i) += 1.0 + } + } + + sample.toBreeze match { + case dv: BDV[Double] => { + var j = 0 + while (j < dv.length) { + update(j, dv(j)) + j += 1 + } + } + case sv: BSV[Double] => + var j = 0 + while (j < sv.data.length) { + update(sv.index(j), sv.data(j)) + j += 1 + } + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } totalCnt += 1 From 2b5e8828a6db72adee10cfbdc71f07d372f43f90 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 28 Oct 2014 22:57:39 -0700 Subject: [PATCH 2/3] small refactoring --- .../stat/MultivariateOnlineSummarizer.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 06001a0421e60..cd6079eb49bdd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.stat -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} /** * :: DeveloperApi :: @@ -91,18 +91,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } } - sample.toBreeze match { - case dv: BDV[Double] => { + sample match { + case dv: DenseVector => { var j = 0 - while (j < dv.length) { - update(j, dv(j)) + while (j < dv.size) { + update(j, dv.values(j)) j += 1 } } - case sv: BSV[Double] => + case sv: SparseVector => var j = 0 - while (j < sv.data.length) { - update(sv.index(j), sv.data(j)) + while (j < sv.size) { + update(sv.indices(j), sv.values(j)) j += 1 } case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) From b99db6caa0a5f2d6e69d5940b5c37e88914c5e36 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 29 Oct 2014 00:25:01 -0700 Subject: [PATCH 3/3] fixed java.lang.ArrayIndexOutOfBoundsException --- .../apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index cd6079eb49bdd..fab7c4405c65d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -101,7 +101,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } case sv: SparseVector => var j = 0 - while (j < sv.size) { + while (j < sv.indices.size) { update(sv.indices(j), sv.values(j)) j += 1 }