From c7c6b642dccb3fbd5f82b81cff7f88a776976c7d Mon Sep 17 00:00:00 2001 From: Patrick Pisciuneri Date: Sat, 21 Sep 2019 14:26:54 -0500 Subject: [PATCH] [SPARK-29121][ML][MLLIB] Support for dot product operation on Vector(s) ### What changes were proposed in this pull request? Support for dot product with: - `ml.linalg.Vector` - `ml.linalg.Vectors` - `mllib.linalg.Vector` - `mllib.linalg.Vectors` ### Why are the changes needed? Dot product is useful for feature engineering and scoring. BLAS routines are already there; only a wrapper is needed. ### Does this PR introduce any user-facing change? No changes to existing behavior; this only adds a new public `dot` method to the `Vector` traits. ### How was this patch tested? Tests were written and added to the appropriate `VectorsSuite` classes. They can be quickly run with: ``` sbt "mllib-local/testOnly org.apache.spark.ml.linalg.VectorsSuite" sbt "mllib/testOnly org.apache.spark.mllib.linalg.VectorsSuite" ``` Closes #25818 from phpisciuneri/SPARK-29121. Authored-by: Patrick Pisciuneri Signed-off-by: Sean Owen --- .../org/apache/spark/ml/linalg/Vectors.scala | 8 +++++++ .../apache/spark/ml/linalg/VectorsSuite.scala | 23 +++++++++++++++++++ .../apache/spark/mllib/linalg/Vectors.scala | 8 +++++++ .../spark/mllib/linalg/VectorsSuite.scala | 23 +++++++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 6e43d60bd03a3..f437d66cddb54 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -178,6 +178,14 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def argmax: Int + + /** + * Calculate the dot product of this vector with another. + * + * If the sizes of the two vectors do not match, an [[IllegalArgumentException]] is thrown. 
+ */ + @Since("3.0.0") + def dot(v: Vector): Double = BLAS.dot(this, v) } /** diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 0a316f57f811b..c97dc2c3c06f8 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -380,4 +380,27 @@ class VectorsSuite extends SparkMLFunSuite { Vectors.sparse(-1, Array((1, 2.0))) } } + + test("dot product only supports vectors of same size") { + val vSize4 = Vectors.dense(arr) + val vSize1 = Vectors.zeros(1) + intercept[IllegalArgumentException]{ vSize1.dot(vSize4) } + } + + test("dense vector dot product") { + val dv = Vectors.dense(arr) + assert(dv.dot(dv) === 0.26) + } + + test("sparse vector dot product") { + val sv = Vectors.sparse(n, indices, values) + assert(sv.dot(sv) === 0.26) + } + + test("mixed sparse and dense vector dot product") { + val sv = Vectors.sparse(n, indices, values) + val dv = Vectors.dense(arr) + assert(sv.dot(dv) === 0.26) + assert(dv.dot(sv) === 0.26) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index b754fad0c1796..83a519326df75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -204,6 +204,14 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def asML: newlinalg.Vector + + /** + * Calculate the dot product of this vector with another. + * + * If the sizes of the two vectors do not match, an [[IllegalArgumentException]] is thrown. 
+ */ + @Since("3.0.0") + def dot(v: Vector): Double = BLAS.dot(this, v) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index fee0b02bf8ed8..b2163b518dbd1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -510,4 +510,27 @@ class VectorsSuite extends SparkFunSuite with Logging { Vectors.sparse(-1, Array((1, 2.0))) } } + + test("dot product only supports vectors of same size") { + val vSize4 = Vectors.dense(arr) + val vSize1 = Vectors.zeros(1) + intercept[IllegalArgumentException]{ vSize1.dot(vSize4) } + } + + test("dense vector dot product") { + val dv = Vectors.dense(arr) + assert(dv.dot(dv) === 0.26) + } + + test("sparse vector dot product") { + val sv = Vectors.sparse(n, indices, values) + assert(sv.dot(sv) === 0.26) + } + + test("mixed sparse and dense vector dot product") { + val sv = Vectors.sparse(n, indices, values) + val dv = Vectors.dense(arr) + assert(sv.dot(dv) === 0.26) + assert(dv.dot(sv) === 0.26) + } }