[SPARK-29121][ML][MLLIB] Support for dot product operation on Vector(s)
### What changes were proposed in this pull request?

Support for dot product with:
- `ml.linalg.Vector`
- `ml.linalg.Vectors`
- `mllib.linalg.Vector`
- `mllib.linalg.Vectors`

### Why are the changes needed?

Dot product is useful for feature engineering and scoring. The BLAS routines already exist; only a wrapper is needed.
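For readers unfamiliar with the operation, the following plain-Scala sketch illustrates what a dot product computes for dense and sparse inputs. It is not Spark's `BLAS.dot` implementation; the object and method names (`DotSketch`, `dotDense`, `dotSparseDense`) are hypothetical and chosen only for this example.

```scala
// Illustrative sketch only; names are hypothetical and not part of Spark.
object DotSketch {
  // Dense-dense: sum of element-wise products. Sizes must match;
  // `require` throws IllegalArgumentException when they do not.
  def dotDense(x: Array[Double], y: Array[Double]): Double = {
    require(x.length == y.length, s"size mismatch: ${x.length} vs ${y.length}")
    var sum = 0.0
    var i = 0
    while (i < x.length) {
      sum += x(i) * y(i)
      i += 1
    }
    sum
  }

  // Sparse-dense: only the stored (index, value) pairs contribute,
  // because every other entry of the sparse vector is zero.
  def dotSparseDense(
      size: Int,
      indices: Array[Int],
      values: Array[Double],
      y: Array[Double]): Double = {
    require(size == y.length, s"size mismatch: $size vs ${y.length}")
    var sum = 0.0
    var k = 0
    while (k < indices.length) {
      sum += values(k) * y(indices(k))
      k += 1
    }
    sum
  }
}
```

For example, `DotSketch.dotDense(Array(1.0, 0.0, 3.0), Array(2.0, 5.0, 4.0))` evaluates to `14.0`, and the sparse form with `indices = Array(0, 2)` and `values = Array(1.0, 3.0)` gives the same result against the same dense array.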

### Does this PR introduce any user-facing change?

No changes to existing user-facing behavior; the PR only adds the new `dot` method.

### How was this patch tested?

Tests were added to the corresponding `VectorsSuite` classes. They can be run quickly with:

```
sbt "mllib-local/testOnly org.apache.spark.ml.linalg.VectorsSuite"
sbt "mllib/testOnly org.apache.spark.mllib.linalg.VectorsSuite"
```

Closes #25818 from phpisciuneri/SPARK-29121.

Authored-by: Patrick Pisciuneri <phpisciuneri@gmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
phpisciuneri authored and srowen committed Sep 21, 2019
1 parent 3e26492 commit c7c6b64
Showing 4 changed files with 62 additions and 0 deletions.
@@ -178,6 +178,14 @@ sealed trait Vector extends Serializable {
*/
@Since("2.0.0")
def argmax: Int

/**
* Calculate the dot product of this vector with another.
*
* If `size` does not match an [[IllegalArgumentException]] is thrown.
*/
@Since("3.0.0")
def dot(v: Vector): Double = BLAS.dot(this, v)
}

/**
@@ -380,4 +380,27 @@ class VectorsSuite extends SparkMLFunSuite {
Vectors.sparse(-1, Array((1, 2.0)))
}
}

test("dot product only supports vectors of same size") {
val vSize4 = Vectors.dense(arr)
val vSize1 = Vectors.zeros(1)
intercept[IllegalArgumentException]{ vSize1.dot(vSize4) }
}

test("dense vector dot product") {
val dv = Vectors.dense(arr)
assert(dv.dot(dv) === 0.26)
}

test("sparse vector dot product") {
val sv = Vectors.sparse(n, indices, values)
assert(sv.dot(sv) === 0.26)
}

test("mixed sparse and dense vector dot product") {
val sv = Vectors.sparse(n, indices, values)
val dv = Vectors.dense(arr)
assert(sv.dot(dv) === 0.26)
assert(dv.dot(sv) === 0.26)
}
}
@@ -204,6 +204,14 @@ sealed trait Vector extends Serializable {
*/
@Since("2.0.0")
def asML: newlinalg.Vector

/**
* Calculate the dot product of this vector with another.
*
* If `size` does not match an [[IllegalArgumentException]] is thrown.
*/
@Since("3.0.0")
def dot(v: Vector): Double = BLAS.dot(this, v)
}

/**
@@ -510,4 +510,27 @@ class VectorsSuite extends SparkFunSuite with Logging {
Vectors.sparse(-1, Array((1, 2.0)))
}
}

test("dot product only supports vectors of same size") {
val vSize4 = Vectors.dense(arr)
val vSize1 = Vectors.zeros(1)
intercept[IllegalArgumentException]{ vSize1.dot(vSize4) }
}

test("dense vector dot product") {
val dv = Vectors.dense(arr)
assert(dv.dot(dv) === 0.26)
}

test("sparse vector dot product") {
val sv = Vectors.sparse(n, indices, values)
assert(sv.dot(sv) === 0.26)
}

test("mixed sparse and dense vector dot product") {
val sv = Vectors.sparse(n, indices, values)
val dv = Vectors.dense(arr)
assert(sv.dot(dv) === 0.26)
assert(dv.dot(sv) === 0.26)
}
}
